diff --git a/CHANGELOG.md b/CHANGELOG.md index 685a6fc..36b0b08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ its affiliates is strictly prohibited. ## Latest Commit ### New Features -- Added validity of start state check for motion_gen plan calls for single queries. +- Add start state checks for world collision, self-collision, and joint limits. ### BugFixes & Misc. - Fix bug in evaluator to account for dof maximum acceleration and jerk. diff --git a/src/curobo/wrap/reacher/motion_gen.py b/src/curobo/wrap/reacher/motion_gen.py index 082b199..004ec31 100644 --- a/src/curobo/wrap/reacher/motion_gen.py +++ b/src/curobo/wrap/reacher/motion_gen.py @@ -41,7 +41,7 @@ import math import time from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union # Third Party import numpy as np @@ -1006,9 +1006,17 @@ class MotionGenStatus(Enum): #: pose/ joint target in joint limits. INVALID_QUERY = "Invalid Query" - #: Invalid query was given. The start state is either out of joint limits, in collision with - #: world, or in self-collision. - INVALID_START = "Invalid Start" + #: Invalid start state was given. Unknown reason. + INVALID_START_STATE_UNKNOWN_ISSUE = "Invalid Start State, unknown issue" + + #: Invalid start state was given. The start state is in world collision. + INVALID_START_STATE_WORLD_COLLISION = "Start state is colliding with world" + + #: Invalid start state was given. The start state is in self-collision. + INVALID_START_STATE_SELF_COLLISION = "Start state is in self-collision" + + #: Invalid start state was given. The start state is out of joint limits. + INVALID_START_STATE_JOINT_LIMITS = "Start state is out of joint limits" #: Motion generation query was successful. SUCCESS = "Success" @@ -1986,15 +1994,14 @@ class MotionGen(MotionGenConfig): n_goalset=1, ) force_graph = plan_config.enable_graph - + valid_query = True if plan_config.check_start_validity: - # check if start state is collision-free and within limits: - valid_query = self.ik_solver.check_valid(start_state.position).item() + valid_query, status = self.check_start_state(start_state) if not valid_query: result = MotionGenResult( success=torch.as_tensor([False], device=self.tensor_args.device), valid_query=valid_query, - status=MotionGenStatus.INVALID_START, + status=status, ) return result @@ -2566,6 +2573,76 @@ class MotionGen(MotionGenConfig): robot_cfg = RobotConfig.from_dict(robot_config_dict, self.tensor_args) self.kinematics.update_kinematics_config(robot_cfg.kinematics.kinematics_config) + def check_start_state( + self, start_state: JointState + ) -> Tuple[bool, Union[None, MotionGenStatus]]: + """Check if the start state is valid for motion generation. + + Args: + start_state: Start joint state of the robot. + + Returns: + Tuple[bool, MotionGenStatus]: Tuple containing True if the start state is valid and + the status of the start state. + """ + joint_position = start_state.position + + if len(joint_position.shape) == 1: + joint_position = joint_position.unsqueeze(0) + if len(joint_position.shape) > 2: + log_error("joint_position should be of shape (batch, dof)") + joint_position = joint_position.unsqueeze(1) + metrics = self.rollout_fn.rollout_constraint( + joint_position, + use_batch_env=False, + ) + valid_query = metrics.feasible.squeeze(1).item() + status = None + if not valid_query: + self.rollout_fn.primitive_collision_constraint.disable_cost() + self.rollout_fn.robot_self_collision_constraint.disable_cost() + within_joint_limits = ( + self.rollout_fn.rollout_constraint( + joint_position, + use_batch_env=False, + ) + .feasible.squeeze(1) + .item() + ) + + self.rollout_fn.primitive_collision_constraint.enable_cost() + + if not within_joint_limits: + self.rollout_fn.robot_self_collision_constraint.enable_cost() + return valid_query, MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS + + self.rollout_fn.primitive_collision_constraint.enable_cost() + world_collision_free = ( + self.rollout_fn.rollout_constraint( + joint_position, + use_batch_env=False, + ) + .feasible.squeeze(1) + .item() + ) + if not world_collision_free: + return valid_query, MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION + + self.rollout_fn.robot_self_collision_constraint.enable_cost() + self_collision_free = ( + self.rollout_fn.rollout_constraint( + joint_position, + use_batch_env=False, + ) + .feasible.squeeze(1) + .item() + ) + + if not self_collision_free: + return valid_query, MotionGenStatus.INVALID_START_STATE_SELF_COLLISION + status = MotionGenStatus.INVALID_START_STATE_UNKNOWN_ISSUE + return (valid_query, status) + @profiler.record_function("motion_gen/ik") def _solve_ik_from_solve_state( self, @@ -2780,13 +2857,12 @@ class MotionGen(MotionGenConfig): start_time = time.time() valid_query = True if plan_config.check_start_validity: - # check if start state is collision-free and within limits: - valid_query = self.ik_solver.check_valid(start_state.position).item() + valid_query, status = self.check_start_state(start_state) if not valid_query: result = MotionGenResult( success=torch.as_tensor([False], device=self.tensor_args.device), valid_query=valid_query, - status=MotionGenStatus.INVALID_START, + status=status, ) return result if plan_config.pose_cost_metric is not None: diff --git a/tests/motion_gen_module_test.py b/tests/motion_gen_module_test.py index 67b60aa..267ae61 100644 --- a/tests/motion_gen_module_test.py +++ b/tests/motion_gen_module_test.py @@ -335,7 +335,16 @@ def test_motion_gen_single_js(motion_gen_str, enable_graph, request): assert torch.norm(goal_state.position - reached_state.position) < 0.05 -def test_motion_gen_single_js_invalid_start(motion_gen): +@pytest.mark.parametrize( + "motion_gen_str,invalid_status", + [ + ("motion_gen", MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS), + ("motion_gen", MotionGenStatus.INVALID_START_STATE_SELF_COLLISION), + ("motion_gen", MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION), + ], +) +def test_motion_gen_single_js_invalid_start(motion_gen_str, invalid_status, request): + motion_gen = request.getfixturevalue(motion_gen_str) motion_gen.reset() @@ -347,18 +356,32 @@ def test_motion_gen_single_js_invalid_start(motion_gen): goal_state = start_state.clone() goal_state.position -= 0.3 - start_state.position[0] += 10.0 - + if invalid_status == MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS: + start_state.position[0, 0] += 10.0 + if invalid_status == MotionGenStatus.INVALID_START_STATE_SELF_COLLISION: + start_state.position[0, 3] = -3.0 + if invalid_status == MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION: + start_state.position[0, 1] = 1.7 result = motion_gen.plan_single_js(start_state, goal_state, m_config) assert torch.count_nonzero(result.success) == 0 assert result.valid_query == False - assert result.status == MotionGenStatus.INVALID_START + assert result.status == invalid_status -def test_motion_gen_single_invalid(motion_gen): +@pytest.mark.parametrize( + "motion_gen_str,invalid_status", + [ + ("motion_gen", MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS), + ("motion_gen", MotionGenStatus.INVALID_START_STATE_SELF_COLLISION), + ("motion_gen", MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION), + ], +) +def test_motion_gen_single_invalid(motion_gen_str, invalid_status, request): + motion_gen = request.getfixturevalue(motion_gen_str) + motion_gen.reset() retract_cfg = motion_gen.get_retract_config() @@ -368,7 +391,12 @@ def test_motion_gen_single_invalid(motion_gen): goal_pose = Pose(state.ee_pos_seq, quaternion=state.ee_quat_seq) start_state = JointState.from_position(retract_cfg.view(1, -1) + 0.3) - start_state.position[..., 1] = 1.7 + if invalid_status == MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS: + start_state.position[0, 0] += 10.0 + if invalid_status == MotionGenStatus.INVALID_START_STATE_SELF_COLLISION: + start_state.position[0, 3] = -3.0 + if invalid_status == MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION: + start_state.position[0, 1] = 1.7 m_config = MotionGenPlanConfig(False, True, max_attempts=1) @@ -379,4 +407,4 @@ def test_motion_gen_single_invalid(motion_gen): assert result.valid_query == False - assert result.status == MotionGenStatus.INVALID_START + assert result.status == invalid_status