improve start state validity check
This commit is contained in:
@@ -13,7 +13,7 @@ its affiliates is strictly prohibited.
|
|||||||
## Latest Commit
|
## Latest Commit
|
||||||
|
|
||||||
### New Features
|
### New Features
|
||||||
- Added validity of start state check for motion_gen plan calls for single queries.
|
- Add start state checks for world collision, self-collision, and joint limits.
|
||||||
|
|
||||||
### BugFixes & Misc.
|
### BugFixes & Misc.
|
||||||
- Fix bug in evaluator to account for dof maximum acceleration and jerk.
|
- Fix bug in evaluator to account for dof maximum acceleration and jerk.
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ import math
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
# Third Party
|
# Third Party
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -1006,9 +1006,17 @@ class MotionGenStatus(Enum):
|
|||||||
#: pose/ joint target in joint limits.
|
#: pose/ joint target in joint limits.
|
||||||
INVALID_QUERY = "Invalid Query"
|
INVALID_QUERY = "Invalid Query"
|
||||||
|
|
||||||
#: Invalid query was given. The start state is either out of joint limits, in collision with
|
#: Invalid start state was given. Unknown reason.
|
||||||
#: world, or in self-collision.
|
INVALID_START_STATE_UNKNOWN_ISSUE = "Invalid Start State, unknown issue"
|
||||||
INVALID_START = "Invalid Start"
|
|
||||||
|
#: Invalid start state was given. The start state is in world collision.
|
||||||
|
INVALID_START_STATE_WORLD_COLLISION = "Start state is colliding with world"
|
||||||
|
|
||||||
|
#: Invalid start state was given. The start state is in self-collision.
|
||||||
|
INVALID_START_STATE_SELF_COLLISION = "Start state is in self-collision"
|
||||||
|
|
||||||
|
#: Invalid start state was given. The start state is out of joint limits.
|
||||||
|
INVALID_START_STATE_JOINT_LIMITS = "Start state is out of joint limits"
|
||||||
|
|
||||||
#: Motion generation query was successful.
|
#: Motion generation query was successful.
|
||||||
SUCCESS = "Success"
|
SUCCESS = "Success"
|
||||||
@@ -1986,15 +1994,14 @@ class MotionGen(MotionGenConfig):
|
|||||||
n_goalset=1,
|
n_goalset=1,
|
||||||
)
|
)
|
||||||
force_graph = plan_config.enable_graph
|
force_graph = plan_config.enable_graph
|
||||||
|
valid_query = True
|
||||||
if plan_config.check_start_validity:
|
if plan_config.check_start_validity:
|
||||||
# check if start state is collision-free and within limits:
|
valid_query, status = self.check_start_state(start_state)
|
||||||
valid_query = self.ik_solver.check_valid(start_state.position).item()
|
|
||||||
if not valid_query:
|
if not valid_query:
|
||||||
result = MotionGenResult(
|
result = MotionGenResult(
|
||||||
success=torch.as_tensor([False], device=self.tensor_args.device),
|
success=torch.as_tensor([False], device=self.tensor_args.device),
|
||||||
valid_query=valid_query,
|
valid_query=valid_query,
|
||||||
status=MotionGenStatus.INVALID_START,
|
status=status,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -2566,6 +2573,76 @@ class MotionGen(MotionGenConfig):
|
|||||||
robot_cfg = RobotConfig.from_dict(robot_config_dict, self.tensor_args)
|
robot_cfg = RobotConfig.from_dict(robot_config_dict, self.tensor_args)
|
||||||
self.kinematics.update_kinematics_config(robot_cfg.kinematics.kinematics_config)
|
self.kinematics.update_kinematics_config(robot_cfg.kinematics.kinematics_config)
|
||||||
|
|
||||||
|
def check_start_state(
|
||||||
|
self, start_state: JointState
|
||||||
|
) -> Tuple[bool, Union[None, MotionGenStatus]]:
|
||||||
|
"""Check if the start state is valid for motion generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_state: Start joint state of the robot.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, MotionGenStatus]: Tuple containing True if the start state is valid and
|
||||||
|
the status of the start state.
|
||||||
|
"""
|
||||||
|
joint_position = start_state.position
|
||||||
|
|
||||||
|
if len(joint_position.shape) == 1:
|
||||||
|
joint_position = joint_position.unsqueeze(0)
|
||||||
|
if len(joint_position.shape) > 2:
|
||||||
|
log_error("joint_position should be of shape (batch, dof)")
|
||||||
|
joint_position = joint_position.unsqueeze(1)
|
||||||
|
metrics = self.rollout_fn.rollout_constraint(
|
||||||
|
joint_position,
|
||||||
|
use_batch_env=False,
|
||||||
|
)
|
||||||
|
valid_query = metrics.feasible.squeeze(1).item()
|
||||||
|
status = None
|
||||||
|
if not valid_query:
|
||||||
|
self.rollout_fn.primitive_collision_constraint.disable_cost()
|
||||||
|
self.rollout_fn.robot_self_collision_constraint.disable_cost()
|
||||||
|
within_joint_limits = (
|
||||||
|
self.rollout_fn.rollout_constraint(
|
||||||
|
joint_position,
|
||||||
|
use_batch_env=False,
|
||||||
|
)
|
||||||
|
.feasible.squeeze(1)
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rollout_fn.primitive_collision_constraint.enable_cost()
|
||||||
|
|
||||||
|
if not within_joint_limits:
|
||||||
|
self.rollout_fn.robot_self_collision_constraint.enable_cost()
|
||||||
|
return valid_query, MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS
|
||||||
|
|
||||||
|
self.rollout_fn.primitive_collision_constraint.enable_cost()
|
||||||
|
world_collision_free = (
|
||||||
|
self.rollout_fn.rollout_constraint(
|
||||||
|
joint_position,
|
||||||
|
use_batch_env=False,
|
||||||
|
)
|
||||||
|
.feasible.squeeze(1)
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
if not world_collision_free:
|
||||||
|
return valid_query, MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION
|
||||||
|
|
||||||
|
self.rollout_fn.robot_self_collision_constraint.enable_cost()
|
||||||
|
self_collision_free = (
|
||||||
|
self.rollout_fn.rollout_constraint(
|
||||||
|
joint_position,
|
||||||
|
use_batch_env=False,
|
||||||
|
)
|
||||||
|
.feasible.squeeze(1)
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self_collision_free:
|
||||||
|
return valid_query, MotionGenStatus.INVALID_START_STATE_SELF_COLLISION
|
||||||
|
status = MotionGenStatus.INVALID_START_STATE_UNKNOWN_ISSUE
|
||||||
|
return (valid_query, status)
|
||||||
|
|
||||||
@profiler.record_function("motion_gen/ik")
|
@profiler.record_function("motion_gen/ik")
|
||||||
def _solve_ik_from_solve_state(
|
def _solve_ik_from_solve_state(
|
||||||
self,
|
self,
|
||||||
@@ -2780,13 +2857,12 @@ class MotionGen(MotionGenConfig):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
valid_query = True
|
valid_query = True
|
||||||
if plan_config.check_start_validity:
|
if plan_config.check_start_validity:
|
||||||
# check if start state is collision-free and within limits:
|
valid_query, status = self.check_start_state(start_state)
|
||||||
valid_query = self.ik_solver.check_valid(start_state.position).item()
|
|
||||||
if not valid_query:
|
if not valid_query:
|
||||||
result = MotionGenResult(
|
result = MotionGenResult(
|
||||||
success=torch.as_tensor([False], device=self.tensor_args.device),
|
success=torch.as_tensor([False], device=self.tensor_args.device),
|
||||||
valid_query=valid_query,
|
valid_query=valid_query,
|
||||||
status=MotionGenStatus.INVALID_START,
|
status=status,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
if plan_config.pose_cost_metric is not None:
|
if plan_config.pose_cost_metric is not None:
|
||||||
|
|||||||
@@ -335,7 +335,16 @@ def test_motion_gen_single_js(motion_gen_str, enable_graph, request):
|
|||||||
assert torch.norm(goal_state.position - reached_state.position) < 0.05
|
assert torch.norm(goal_state.position - reached_state.position) < 0.05
|
||||||
|
|
||||||
|
|
||||||
def test_motion_gen_single_js_invalid_start(motion_gen):
|
@pytest.mark.parametrize(
|
||||||
|
"motion_gen_str,invalid_status",
|
||||||
|
[
|
||||||
|
("motion_gen", MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS),
|
||||||
|
("motion_gen", MotionGenStatus.INVALID_START_STATE_SELF_COLLISION),
|
||||||
|
("motion_gen", MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_motion_gen_single_js_invalid_start(motion_gen_str, invalid_status, request):
|
||||||
|
motion_gen = request.getfixturevalue(motion_gen_str)
|
||||||
|
|
||||||
motion_gen.reset()
|
motion_gen.reset()
|
||||||
|
|
||||||
@@ -347,18 +356,32 @@ def test_motion_gen_single_js_invalid_start(motion_gen):
|
|||||||
|
|
||||||
goal_state = start_state.clone()
|
goal_state = start_state.clone()
|
||||||
goal_state.position -= 0.3
|
goal_state.position -= 0.3
|
||||||
start_state.position[0] += 10.0
|
if invalid_status == MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS:
|
||||||
|
start_state.position[0, 0] += 10.0
|
||||||
|
if invalid_status == MotionGenStatus.INVALID_START_STATE_SELF_COLLISION:
|
||||||
|
start_state.position[0, 3] = -3.0
|
||||||
|
if invalid_status == MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION:
|
||||||
|
start_state.position[0, 1] = 1.7
|
||||||
result = motion_gen.plan_single_js(start_state, goal_state, m_config)
|
result = motion_gen.plan_single_js(start_state, goal_state, m_config)
|
||||||
|
|
||||||
assert torch.count_nonzero(result.success) == 0
|
assert torch.count_nonzero(result.success) == 0
|
||||||
|
|
||||||
assert result.valid_query == False
|
assert result.valid_query == False
|
||||||
|
|
||||||
assert result.status == MotionGenStatus.INVALID_START
|
assert result.status == invalid_status
|
||||||
|
|
||||||
|
|
||||||
def test_motion_gen_single_invalid(motion_gen):
|
@pytest.mark.parametrize(
|
||||||
|
"motion_gen_str,invalid_status",
|
||||||
|
[
|
||||||
|
("motion_gen", MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS),
|
||||||
|
("motion_gen", MotionGenStatus.INVALID_START_STATE_SELF_COLLISION),
|
||||||
|
("motion_gen", MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_motion_gen_single_invalid(motion_gen_str, invalid_status, request):
|
||||||
|
motion_gen = request.getfixturevalue(motion_gen_str)
|
||||||
|
|
||||||
motion_gen.reset()
|
motion_gen.reset()
|
||||||
|
|
||||||
retract_cfg = motion_gen.get_retract_config()
|
retract_cfg = motion_gen.get_retract_config()
|
||||||
@@ -368,7 +391,12 @@ def test_motion_gen_single_invalid(motion_gen):
|
|||||||
goal_pose = Pose(state.ee_pos_seq, quaternion=state.ee_quat_seq)
|
goal_pose = Pose(state.ee_pos_seq, quaternion=state.ee_quat_seq)
|
||||||
|
|
||||||
start_state = JointState.from_position(retract_cfg.view(1, -1) + 0.3)
|
start_state = JointState.from_position(retract_cfg.view(1, -1) + 0.3)
|
||||||
start_state.position[..., 1] = 1.7
|
if invalid_status == MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS:
|
||||||
|
start_state.position[0, 0] += 10.0
|
||||||
|
if invalid_status == MotionGenStatus.INVALID_START_STATE_SELF_COLLISION:
|
||||||
|
start_state.position[0, 3] = -3.0
|
||||||
|
if invalid_status == MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION:
|
||||||
|
start_state.position[0, 1] = 1.7
|
||||||
|
|
||||||
m_config = MotionGenPlanConfig(False, True, max_attempts=1)
|
m_config = MotionGenPlanConfig(False, True, max_attempts=1)
|
||||||
|
|
||||||
@@ -379,4 +407,4 @@ def test_motion_gen_single_invalid(motion_gen):
|
|||||||
|
|
||||||
assert result.valid_query == False
|
assert result.valid_query == False
|
||||||
|
|
||||||
assert result.status == MotionGenStatus.INVALID_START
|
assert result.status == invalid_status
|
||||||
|
|||||||
Reference in New Issue
Block a user