Isaac Sim 4.0 support, Kinematics API doc, Windows support
This commit is contained in:
@@ -421,7 +421,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
|
||||
def get_metrics(self, state: Union[JointState, KinematicModelState]):
|
||||
"""Compute metrics given state
|
||||
#TODO: Currently does not compute velocity and acceleration costs.
|
||||
|
||||
Args:
|
||||
state (Union[JointState, URDFModelState]): _description_
|
||||
|
||||
@@ -429,6 +429,8 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
_type_: _description_
|
||||
|
||||
"""
|
||||
if self.cuda_graph_instance:
|
||||
log_error("Cuda graph is using this instance, please break the graph before using this")
|
||||
if isinstance(state, JointState):
|
||||
state = self._get_augmented_state(state)
|
||||
out_metrics = self.constraint_fn(state)
|
||||
@@ -462,6 +464,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
with torch.cuda.graph(self.cu_metrics_graph, stream=s):
|
||||
self._cu_out_metrics = self.get_metrics(self._cu_metrics_state_in)
|
||||
self._metrics_cuda_graph_init = True
|
||||
self._cuda_graph_valid = True
|
||||
if not self.cuda_graph_instance:
|
||||
log_error("cuda graph is invalid")
|
||||
if self._cu_metrics_state_in.position.shape != state.position.shape:
|
||||
log_error("cuda graph changed")
|
||||
self._cu_metrics_state_in.copy_(state)
|
||||
@@ -541,6 +546,8 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
def rollout_constraint(
|
||||
self, act_seq: torch.Tensor, use_batch_env: bool = True
|
||||
) -> RolloutMetrics:
|
||||
if self.cuda_graph_instance:
|
||||
log_error("Cuda graph is using this instance, please break the graph before using this")
|
||||
state = self.dynamics_model.forward(self.start_state, act_seq)
|
||||
metrics = self.constraint_fn(state, use_batch_env=use_batch_env)
|
||||
return metrics
|
||||
@@ -566,6 +573,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
state, use_batch_env=use_batch_env
|
||||
)
|
||||
self._rollout_constraint_cuda_graph_init = True
|
||||
self._cuda_graph_valid = True
|
||||
if not self.cuda_graph_instance:
|
||||
log_error("cuda graph is invalid")
|
||||
self._cu_rollout_constraint_act_in.copy_(act_seq)
|
||||
self.cu_rollout_constraint_graph.replay()
|
||||
out_metrics = self._cu_rollout_constraint_out_metrics
|
||||
@@ -602,12 +612,8 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
|
||||
"""
|
||||
with profiler.record_function("arm_base/update_params"):
|
||||
self._goal_buffer.copy_(
|
||||
goal, update_idx_buffers=self._goal_idx_update
|
||||
) # TODO: convert this to a reference to avoid extra copy
|
||||
# self._goal_buffer.copy_(goal, update_idx_buffers=True) # TODO: convert this to a reference to avoid extra copy
|
||||
self._goal_buffer.copy_(goal, update_idx_buffers=self._goal_idx_update)
|
||||
|
||||
# TODO: move start state also inside Goal instance
|
||||
if goal.current_state is not None:
|
||||
if self.start_state is None:
|
||||
self.start_state = goal.current_state.clone()
|
||||
|
||||
Reference in New Issue
Block a user