Isaac Sim 4.0 support, Kinematics API doc, Windows support

This commit is contained in:
Balakumar Sundaralingam
2024-07-20 14:51:43 -07:00
parent 2ae381f328
commit 3690d28c54
83 changed files with 2818 additions and 497 deletions

View File

@@ -421,7 +421,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
def get_metrics(self, state: Union[JointState, KinematicModelState]):
"""Compute metrics given state
#TODO: Currently does not compute velocity and acceleration costs.
Args:
state (Union[JointState, URDFModelState]): _description_
@@ -429,6 +429,8 @@ class ArmBase(RolloutBase, ArmBaseConfig):
_type_: _description_
"""
if self.cuda_graph_instance:
log_error("Cuda graph is using this instance, please break the graph before using this")
if isinstance(state, JointState):
state = self._get_augmented_state(state)
out_metrics = self.constraint_fn(state)
@@ -462,6 +464,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
with torch.cuda.graph(self.cu_metrics_graph, stream=s):
self._cu_out_metrics = self.get_metrics(self._cu_metrics_state_in)
self._metrics_cuda_graph_init = True
self._cuda_graph_valid = True
if not self.cuda_graph_instance:
log_error("cuda graph is invalid")
if self._cu_metrics_state_in.position.shape != state.position.shape:
log_error("cuda graph changed")
self._cu_metrics_state_in.copy_(state)
@@ -541,6 +546,8 @@ class ArmBase(RolloutBase, ArmBaseConfig):
def rollout_constraint(
self, act_seq: torch.Tensor, use_batch_env: bool = True
) -> RolloutMetrics:
if self.cuda_graph_instance:
log_error("Cuda graph is using this instance, please break the graph before using this")
state = self.dynamics_model.forward(self.start_state, act_seq)
metrics = self.constraint_fn(state, use_batch_env=use_batch_env)
return metrics
@@ -566,6 +573,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
state, use_batch_env=use_batch_env
)
self._rollout_constraint_cuda_graph_init = True
self._cuda_graph_valid = True
if not self.cuda_graph_instance:
log_error("cuda graph is invalid")
self._cu_rollout_constraint_act_in.copy_(act_seq)
self.cu_rollout_constraint_graph.replay()
out_metrics = self._cu_rollout_constraint_out_metrics
@@ -602,12 +612,8 @@ class ArmBase(RolloutBase, ArmBaseConfig):
"""
with profiler.record_function("arm_base/update_params"):
self._goal_buffer.copy_(
goal, update_idx_buffers=self._goal_idx_update
) # TODO: convert this to a reference to avoid extra copy
# self._goal_buffer.copy_(goal, update_idx_buffers=True) # TODO: convert this to a reference to avoid extra copy
self._goal_buffer.copy_(goal, update_idx_buffers=self._goal_idx_update)
# TODO: move start state also inside Goal instance
if goal.current_state is not None:
if self.start_state is None:
self.start_state = goal.current_state.clone()