improve start state validity check

This commit is contained in:
Balakumar Sundaralingam
2024-05-15 12:09:23 -07:00
parent 911da8cb24
commit 3bfed9d773
3 changed files with 123 additions and 19 deletions

View File

@@ -41,7 +41,7 @@ import math
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
# Third Party
import numpy as np
@@ -1006,9 +1006,17 @@ class MotionGenStatus(Enum):
#: pose/ joint target in joint limits.
INVALID_QUERY = "Invalid Query"
#: Invalid query was given. The start state is either out of joint limits, in collision with
#: world, or in self-collision.
INVALID_START = "Invalid Start"
#: Invalid start state was given. Unknown reason.
INVALID_START_STATE_UNKNOWN_ISSUE = "Invalid Start State, unknown issue"
#: Invalid start state was given. The start state is in world collision.
INVALID_START_STATE_WORLD_COLLISION = "Start state is colliding with world"
#: Invalid start state was given. The start state is in self-collision.
INVALID_START_STATE_SELF_COLLISION = "Start state is in self-collision"
#: Invalid start state was given. The start state is out of joint limits.
INVALID_START_STATE_JOINT_LIMITS = "Start state is out of joint limits"
#: Motion generation query was successful.
SUCCESS = "Success"
@@ -1986,15 +1994,14 @@ class MotionGen(MotionGenConfig):
n_goalset=1,
)
force_graph = plan_config.enable_graph
valid_query = True
if plan_config.check_start_validity:
# check if start state is collision-free and within limits:
valid_query = self.ik_solver.check_valid(start_state.position).item()
valid_query, status = self.check_start_state(start_state)
if not valid_query:
result = MotionGenResult(
success=torch.as_tensor([False], device=self.tensor_args.device),
valid_query=valid_query,
status=MotionGenStatus.INVALID_START,
status=status,
)
return result
@@ -2566,6 +2573,76 @@ class MotionGen(MotionGenConfig):
robot_cfg = RobotConfig.from_dict(robot_config_dict, self.tensor_args)
self.kinematics.update_kinematics_config(robot_cfg.kinematics.kinematics_config)
def check_start_state(
self, start_state: JointState
) -> Tuple[bool, Union[None, MotionGenStatus]]:
"""Check if the start state is valid for motion generation.
Args:
start_state: Start joint state of the robot.
Returns:
Tuple[bool, MotionGenStatus]: Tuple containing True if the start state is valid and
the status of the start state.
"""
joint_position = start_state.position
if len(joint_position.shape) == 1:
joint_position = joint_position.unsqueeze(0)
if len(joint_position.shape) > 2:
log_error("joint_position should be of shape (batch, dof)")
joint_position = joint_position.unsqueeze(1)
metrics = self.rollout_fn.rollout_constraint(
joint_position,
use_batch_env=False,
)
valid_query = metrics.feasible.squeeze(1).item()
status = None
if not valid_query:
self.rollout_fn.primitive_collision_constraint.disable_cost()
self.rollout_fn.robot_self_collision_constraint.disable_cost()
within_joint_limits = (
self.rollout_fn.rollout_constraint(
joint_position,
use_batch_env=False,
)
.feasible.squeeze(1)
.item()
)
self.rollout_fn.primitive_collision_constraint.enable_cost()
if not within_joint_limits:
self.rollout_fn.robot_self_collision_constraint.enable_cost()
return valid_query, MotionGenStatus.INVALID_START_STATE_JOINT_LIMITS
self.rollout_fn.primitive_collision_constraint.enable_cost()
world_collision_free = (
self.rollout_fn.rollout_constraint(
joint_position,
use_batch_env=False,
)
.feasible.squeeze(1)
.item()
)
if not world_collision_free:
return valid_query, MotionGenStatus.INVALID_START_STATE_WORLD_COLLISION
self.rollout_fn.robot_self_collision_constraint.enable_cost()
self_collision_free = (
self.rollout_fn.rollout_constraint(
joint_position,
use_batch_env=False,
)
.feasible.squeeze(1)
.item()
)
if not self_collision_free:
return valid_query, MotionGenStatus.INVALID_START_STATE_SELF_COLLISION
status = MotionGenStatus.INVALID_START_STATE_UNKNOWN_ISSUE
return (valid_query, status)
@profiler.record_function("motion_gen/ik")
def _solve_ik_from_solve_state(
self,
@@ -2780,13 +2857,12 @@ class MotionGen(MotionGenConfig):
start_time = time.time()
valid_query = True
if plan_config.check_start_validity:
# check if start state is collision-free and within limits:
valid_query = self.ik_solver.check_valid(start_state.position).item()
valid_query, status = self.check_start_state(start_state)
if not valid_query:
result = MotionGenResult(
success=torch.as_tensor([False], device=self.tensor_args.device),
valid_query=valid_query,
status=MotionGenStatus.INVALID_START,
status=status,
)
return result
if plan_config.pose_cost_metric is not None: