update to 0.6.2

This commit is contained in:
Balakumar Sundaralingam
2023-12-15 02:01:33 -08:00
parent d85ae41fba
commit 58958bbcce
105 changed files with 2514 additions and 934 deletions

View File

@@ -324,9 +324,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
# set start state:
start_state = torch.randn((1, self.dynamics_model.d_state), **vars(self.tensor_args))
self._start_state = JointState(
position=start_state[:, : self.dynamics_model.d_action],
velocity=start_state[:, : self.dynamics_model.d_action],
acceleration=start_state[:, : self.dynamics_model.d_action],
position=start_state[:, : self.dynamics_model.d_dof],
velocity=start_state[:, : self.dynamics_model.d_dof],
acceleration=start_state[:, : self.dynamics_model.d_dof],
)
self.update_cost_dt(self.dynamics_model.dt_traj_params.base_dt)
return RolloutBase._init_after_config_load(self)
@@ -680,6 +680,10 @@ class ArmBase(RolloutBase, ArmBaseConfig):
def horizon(self):
return self.dynamics_model.horizon
@property
def action_horizon(self):
return self.dynamics_model.action_horizon
def get_init_action_seq(self) -> torch.Tensor:
act_seq = self.dynamics_model.init_action_mean.unsqueeze(0).repeat(self.batch_size, 1, 1)
return act_seq