improved joint space planning

This commit is contained in:
Balakumar Sundaralingam
2024-05-30 14:42:22 -07:00
parent 3bfed9d773
commit 0c51dd2da8
28 changed files with 1135 additions and 213 deletions

View File

@@ -261,7 +261,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
log_warn(
"null space cost is deprecated, use null_space_weight in bound cost instead"
)
self.cost_cfg.bound_cfg.dof = self.n_dofs
self.bound_cost = BoundCost(self.cost_cfg.bound_cfg)
if self.cost_cfg.manipulability_cfg is not None:
@@ -315,10 +315,12 @@ class ArmBase(RolloutBase, ArmBaseConfig):
self.cost_cfg.bound_cfg.state_finite_difference_mode = (
self.dynamics_model.state_finite_difference_mode
)
self.cost_cfg.bound_cfg.dof = self.n_dofs
self.constraint_cfg.bound_cfg.dof = self.n_dofs
self.bound_constraint = BoundCost(self.constraint_cfg.bound_cfg)
if self.convergence_cfg.null_space_cfg is not None:
self.convergence_cfg.null_space_cfg.dof = self.n_dofs
self.null_convergence = DistCost(self.convergence_cfg.null_space_cfg)
# set start state:
@@ -578,6 +580,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
----------
action_seq: torch.Tensor [num_particles, horizon, d_act]
"""
# print(act_seq.shape, self._goal_buffer.batch_current_state_idx)
if self.start_state is None:
raise ValueError("start_state is not set in rollout")
@@ -585,6 +588,7 @@ class ArmBase(RolloutBase, ArmBaseConfig):
state = self.dynamics_model.forward(
self.start_state, act_seq, self._goal_buffer.batch_current_state_idx
)
with profiler.record_function("cost/all"):
cost_seq = self.cost_fn(state, act_seq)