improve start state validity check

This commit is contained in:
Balakumar Sundaralingam
2024-05-15 12:09:23 -07:00
parent 911da8cb24
commit 3bfed9d773
3 changed files with 123 additions and 19 deletions

View File

@@ -13,7 +13,7 @@ its affiliates is strictly prohibited.
## Latest Commit ## Latest Commit
### New Features ### New Features
- Added validity of start state check for motion_gen plan calls for single queries. - Add start state checks for world collision, self-collision, and joint limits.
### BugFixes & Misc. ### BugFixes & Misc.
- Fix bug in evaluator to account for dof maximum acceleration and jerk. - Fix bug in evaluator to account for dof maximum acceleration and jerk.

View File

@@ -41,7 +41,7 @@ import math
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Tuple, Union
# Third Party # Third Party
import numpy as np import numpy as np
@@ -1006,9 +1006,17 @@ class MotionGenStatus(Enum):
#: pose/ joint target in joint limits. #: pose/ joint target in joint limits.
INVALID_QUERY = "Invalid Query" INVALID_QUERY = "Invalid Query"
#: Invalid query was given. The start state is either out of joint limits, in collision with #: Invalid start state was given. Unknown reason.
#: world, or in self-collision. INVALID_START_STATE_UNKNOWN_ISSUE = "Invalid Start State, unknown issue"
INVALID_START = "Invalid Start"
#: Invalid start state was given. The start state is in world collision.
INVALID_START_STATE_WORLD_COLLISION = "Start state is colliding with world"
#: Invalid start state was given. The start state is in self-collision.
INVALID_START_STATE_SELF_COLLISION = "Start state is in self-collision"
#: Invalid start state was given. The start state is out of joint limits.
INVALID_START_STATE_JOINT_LIMITS = "Start state is out of joint limits"
#: Motion generation query was successful. #: Motion generation query was successful.
SUCCESS = "Success" SUCCESS = "Success"
@@ -1986,15 +1994,14 @@ class MotionGen(MotionGenConfig):
n_goalset=1, n_goalset=1,
) )
force_graph = plan_config.enable_graph force_graph = plan_config.enable_graph
valid_query = True
if plan_config.check_start_validity: if plan_config.check_start_validity:
# check if start state is collision-free and within limits: valid_query, status = self.check_start_state(start_state)
valid_query = self.ik_solver.check_valid(start_state.position).item()
if not valid_query: if not valid_query:
result = MotionGenResult( result = MotionGenResult(
success=torch.as_tensor([False], device=self.tensor_args.device), success=torch.as_tensor([False], device=self.tensor_args.device),
valid_query=valid_query, valid_query=valid_query,
status=MotionGenStatus.INVALID_START, status=status,
) )
return result return result
@@ -2566,6 +2573,76 @@ class MotionGen(MotionGenConfig):
robot_cfg = RobotConfig.from_dict(robot_config_dict, self.tensor_args) robot_cfg = RobotConfig.from_dict(robot_config_dict, self.tensor_args)
self.kinematics.update_kinematics_config(robot_cfg.kinematics.kinematics_config) self.kinematics.update_kinematics_config(robot_cfg.kinematics.kinematics_config)
def check_start_state(
self, start_state: JointState
) -> Tuple[bool, Union[None, MotionGenStatus]]:
"""Check if the start state is valid for motion generation.
Args:
start_state: Start joint state of the robot.
Returns:
Tuple[bool, MotionGenStatus]: Tuple containing True if the start state is valid and
the status of the start state.
"""
joint_position = start_state.position
if len(joint_position.shape) == 1:
joint_position = joint_position.unsqueeze(0)
if len(joint_position.shape) > 2:
log_error("joint_position should be of shape (batch, dof)")
joint_position = joint_position.unsqueeze(1)
metrics = self.rollout_fn.rollout_constraint(
joint_position,
use_batch_env=False,
)
valid_query = metrics.feasible.squeeze(1).item()
status = None
if not valid_query:
self.rollout_fn.primitive_collision_constraint.disable_cost()
self.rollout_fn.robot_self_collision_constraint.disable_cost()
within_joint_limits = (
self.rollout_fn.rollout_constraint(
joint_position,
use_batch_env=False,
)
.feasible.squeeze(1)
.item()
)
self.rollout_fn.primitive_collision_constraint.enable_cost()
if not within_joint_limits:
self.rollout_fn.robot_self_collision_constraint.enable_cost()
return valid_query, MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS
self.rollout_fn.primitive_collision_constraint.enable_cost()
world_collision_free = (
self.rollout_fn.rollout_constraint(
joint_position,
use_batch_env=False,
)
.feasible.squeeze(1)
.item()
)
if not world_collision_free:
return valid_query, MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION
self.rollout_fn.robot_self_collision_constraint.enable_cost()
self_collision_free = (
self.rollout_fn.rollout_constraint(
joint_position,
use_batch_env=False,
)
.feasible.squeeze(1)
.item()
)
if not self_collision_free:
return valid_query, MotionGenStatus.INVALID_START_STATE_SELF_COLLISION
status = MotionGenStatus.INVALID_START_STATE_UNKNOWN_ISSUE
return (valid_query, status)
@profiler.record_function("motion_gen/ik") @profiler.record_function("motion_gen/ik")
def _solve_ik_from_solve_state( def _solve_ik_from_solve_state(
self, self,
@@ -2780,13 +2857,12 @@ class MotionGen(MotionGenConfig):
start_time = time.time() start_time = time.time()
valid_query = True valid_query = True
if plan_config.check_start_validity: if plan_config.check_start_validity:
# check if start state is collision-free and within limits: valid_query, status = self.check_start_state(start_state)
valid_query = self.ik_solver.check_valid(start_state.position).item()
if not valid_query: if not valid_query:
result = MotionGenResult( result = MotionGenResult(
success=torch.as_tensor([False], device=self.tensor_args.device), success=torch.as_tensor([False], device=self.tensor_args.device),
valid_query=valid_query, valid_query=valid_query,
status=MotionGenStatus.INVALID_START, status=status,
) )
return result return result
if plan_config.pose_cost_metric is not None: if plan_config.pose_cost_metric is not None:

View File

@@ -335,7 +335,16 @@ def test_motion_gen_single_js(motion_gen_str, enable_graph, request):
assert torch.norm(goal_state.position - reached_state.position) < 0.05 assert torch.norm(goal_state.position - reached_state.position) < 0.05
def test_motion_gen_single_js_invalid_start(motion_gen): @pytest.mark.parametrize(
"motion_gen_str,invalid_status",
[
("motion_gen", MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS),
("motion_gen", MotionGenStatus.INVALID_START_STATE_SELF_COLLISION),
("motion_gen", MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION),
],
)
def test_motion_gen_single_js_invalid_start(motion_gen_str, invalid_status, request):
motion_gen = request.getfixturevalue(motion_gen_str)
motion_gen.reset() motion_gen.reset()
@@ -347,18 +356,32 @@ def test_motion_gen_single_js_invalid_start(motion_gen):
goal_state = start_state.clone() goal_state = start_state.clone()
goal_state.position -= 0.3 goal_state.position -= 0.3
start_state.position[0] += 10.0 if invalid_status == MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS:
start_state.position[0, 0] += 10.0
if invalid_status == MotionGenStatus.INVALID_START_STATE_SELF_COLLISION:
start_state.position[0, 3] = -3.0
if invalid_status == MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION:
start_state.position[0, 1] = 1.7
result = motion_gen.plan_single_js(start_state, goal_state, m_config) result = motion_gen.plan_single_js(start_state, goal_state, m_config)
assert torch.count_nonzero(result.success) == 0 assert torch.count_nonzero(result.success) == 0
assert result.valid_query == False assert result.valid_query == False
assert result.status == MotionGenStatus.INVALID_START assert result.status == invalid_status
def test_motion_gen_single_invalid(motion_gen): @pytest.mark.parametrize(
"motion_gen_str,invalid_status",
[
("motion_gen", MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS),
("motion_gen", MotionGenStatus.INVALID_START_STATE_SELF_COLLISION),
("motion_gen", MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION),
],
)
def test_motion_gen_single_invalid(motion_gen_str, invalid_status, request):
motion_gen = request.getfixturevalue(motion_gen_str)
motion_gen.reset() motion_gen.reset()
retract_cfg = motion_gen.get_retract_config() retract_cfg = motion_gen.get_retract_config()
@@ -368,7 +391,12 @@ def test_motion_gen_single_invalid(motion_gen):
goal_pose = Pose(state.ee_pos_seq, quaternion=state.ee_quat_seq) goal_pose = Pose(state.ee_pos_seq, quaternion=state.ee_quat_seq)
start_state = JointState.from_position(retract_cfg.view(1, -1) + 0.3) start_state = JointState.from_position(retract_cfg.view(1, -1) + 0.3)
start_state.position[..., 1] = 1.7 if invalid_status == MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS:
start_state.position[0, 0] += 10.0
if invalid_status == MotionGenStatus.INVALID_START_STATE_SELF_COLLISION:
start_state.position[0, 3] = -3.0
if invalid_status == MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION:
start_state.position[0, 1] = 1.7
m_config = MotionGenPlanConfig(False, True, max_attempts=1) m_config = MotionGenPlanConfig(False, True, max_attempts=1)
@@ -379,4 +407,4 @@ def test_motion_gen_single_invalid(motion_gen):
assert result.valid_query == False assert result.valid_query == False
assert result.status == MotionGenStatus.INVALID_START assert result.status == invalid_status