constrained planning, robot segmentation
This commit is contained in:
@@ -39,7 +39,7 @@ from curobo.rollout.rollout_base import Goal, RolloutBase, RolloutConfig, Rollou
|
||||
from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import CSpaceConfig, RobotConfig
|
||||
from curobo.types.state import JointState
|
||||
from curobo.util.logger import log_info, log_warn
|
||||
from curobo.util.logger import log_error, log_info, log_warn
|
||||
from curobo.util.tensor_util import cat_sum
|
||||
|
||||
|
||||
@@ -366,6 +366,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
)
|
||||
cost_list.append(coll_cost)
|
||||
if return_list:
|
||||
|
||||
return cost_list
|
||||
cost = cat_sum(cost_list)
|
||||
return cost
|
||||
@@ -424,6 +425,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
out_metrics = self.constraint_fn(state)
|
||||
out_metrics.state = state
|
||||
out_metrics = self.convergence_fn(state, out_metrics)
|
||||
out_metrics.cost = self.cost_fn(state)
|
||||
return out_metrics
|
||||
|
||||
def get_metrics_cuda_graph(self, state: JointState):
|
||||
@@ -451,6 +453,8 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
with torch.cuda.graph(self.cu_metrics_graph, stream=s):
|
||||
self._cu_out_metrics = self.get_metrics(self._cu_metrics_state_in)
|
||||
self._metrics_cuda_graph_init = True
|
||||
if self._cu_metrics_state_in.position.shape != state.position.shape:
|
||||
log_error("cuda graph changed")
|
||||
self._cu_metrics_state_in.copy_(state)
|
||||
self.cu_metrics_graph.replay()
|
||||
out_metrics = self._cu_out_metrics
|
||||
@@ -462,17 +466,6 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
):
|
||||
if out_metrics is None:
|
||||
out_metrics = RolloutMetrics()
|
||||
if (
|
||||
self.convergence_cfg.null_space_cfg is not None
|
||||
and self.null_convergence.enabled
|
||||
and self._goal_buffer.batch_retract_state_idx is not None
|
||||
):
|
||||
out_metrics.cost = self.null_convergence.forward_target_idx(
|
||||
self._goal_buffer.retract_state,
|
||||
state.state_seq.position,
|
||||
self._goal_buffer.batch_retract_state_idx,
|
||||
)
|
||||
|
||||
return out_metrics
|
||||
|
||||
def _get_augmented_state(self, state: JointState) -> KinematicModelState:
|
||||
@@ -688,9 +681,11 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
act_seq = self.dynamics_model.init_action_mean.unsqueeze(0).repeat(self.batch_size, 1, 1)
|
||||
return act_seq
|
||||
|
||||
def reset_cuda_graph(self):
|
||||
def reset_shape(self):
|
||||
self._goal_idx_update = True
|
||||
super().reset_shape()
|
||||
|
||||
def reset_cuda_graph(self):
|
||||
super().reset_cuda_graph()
|
||||
|
||||
def get_action_from_state(self, state: JointState):
|
||||
|
||||
Reference in New Issue
Block a user