constrained planning, robot segmentation

This commit is contained in:
Balakumar Sundaralingam
2024-02-22 21:45:47 -08:00
parent 88eac64edc
commit bafdf80c05
102 changed files with 12440 additions and 8112 deletions

View File

@@ -39,7 +39,7 @@ from curobo.rollout.rollout_base import Goal, RolloutBase, RolloutConfig, Rollou
from curobo.types.base import TensorDeviceType
from curobo.types.robot import CSpaceConfig, RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_info, log_warn
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.tensor_util import cat_sum
@@ -366,6 +366,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
)
cost_list.append(coll_cost)
if return_list:
return cost_list
cost = cat_sum(cost_list)
return cost
@@ -424,6 +425,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
out_metrics = self.constraint_fn(state)
out_metrics.state = state
out_metrics = self.convergence_fn(state, out_metrics)
out_metrics.cost = self.cost_fn(state)
return out_metrics
def get_metrics_cuda_graph(self, state: JointState):
@@ -451,6 +453,8 @@ class ArmBase(RolloutBase, ArmBaseConfig):
with torch.cuda.graph(self.cu_metrics_graph, stream=s):
self._cu_out_metrics = self.get_metrics(self._cu_metrics_state_in)
self._metrics_cuda_graph_init = True
if self._cu_metrics_state_in.position.shape != state.position.shape:
log_error("cuda graph changed")
self._cu_metrics_state_in.copy_(state)
self.cu_metrics_graph.replay()
out_metrics = self._cu_out_metrics
@@ -462,17 +466,6 @@ class ArmBase(RolloutBase, ArmBaseConfig):
):
if out_metrics is None:
out_metrics = RolloutMetrics()
if (
self.convergence_cfg.null_space_cfg is not None
and self.null_convergence.enabled
and self._goal_buffer.batch_retract_state_idx is not None
):
out_metrics.cost = self.null_convergence.forward_target_idx(
self._goal_buffer.retract_state,
state.state_seq.position,
self._goal_buffer.batch_retract_state_idx,
)
return out_metrics
def _get_augmented_state(self, state: JointState) -> KinematicModelState:
@@ -688,9 +681,11 @@ class ArmBase(RolloutBase, ArmBaseConfig):
act_seq = self.dynamics_model.init_action_mean.unsqueeze(0).repeat(self.batch_size, 1, 1)
return act_seq
def reset_cuda_graph(self):
def reset_shape(self):
self._goal_idx_update = True
super().reset_shape()
def reset_cuda_graph(self):
super().reset_cuda_graph()
def get_action_from_state(self, state: JointState):