update to 0.6.2
This commit is contained in:
@@ -324,9 +324,9 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
# set start state:
|
||||
start_state = torch.randn((1, self.dynamics_model.d_state), **vars(self.tensor_args))
|
||||
self._start_state = JointState(
|
||||
position=start_state[:, : self.dynamics_model.d_action],
|
||||
velocity=start_state[:, : self.dynamics_model.d_action],
|
||||
acceleration=start_state[:, : self.dynamics_model.d_action],
|
||||
position=start_state[:, : self.dynamics_model.d_dof],
|
||||
velocity=start_state[:, : self.dynamics_model.d_dof],
|
||||
acceleration=start_state[:, : self.dynamics_model.d_dof],
|
||||
)
|
||||
self.update_cost_dt(self.dynamics_model.dt_traj_params.base_dt)
|
||||
return RolloutBase._init_after_config_load(self)
|
||||
@@ -680,6 +680,10 @@ class ArmBase(RolloutBase, ArmBaseConfig):
|
||||
def horizon(self):
|
||||
return self.dynamics_model.horizon
|
||||
|
||||
@property
|
||||
def action_horizon(self):
|
||||
return self.dynamics_model.action_horizon
|
||||
|
||||
def get_init_action_seq(self) -> torch.Tensor:
|
||||
act_seq = self.dynamics_model.init_action_mean.unsqueeze(0).repeat(self.batch_size, 1, 1)
|
||||
return act_seq
|
||||
|
||||
@@ -29,6 +29,7 @@ from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.robot import RobotConfig
|
||||
from curobo.types.tensor import T_BValue_float
|
||||
from curobo.util.helpers import list_idx_if_not_none
|
||||
from curobo.util.logger import log_info
|
||||
from curobo.util.tensor_util import cat_max, cat_sum
|
||||
|
||||
# Local Folder
|
||||
@@ -79,6 +80,7 @@ class ArmReacherCostConfig(ArmCostConfig):
|
||||
zero_acc_cfg: Optional[CostConfig] = None
|
||||
zero_vel_cfg: Optional[CostConfig] = None
|
||||
zero_jerk_cfg: Optional[CostConfig] = None
|
||||
link_pose_cfg: Optional[PoseCostConfig] = None
|
||||
|
||||
@staticmethod
|
||||
def _get_base_keys():
|
||||
@@ -91,6 +93,7 @@ class ArmReacherCostConfig(ArmCostConfig):
|
||||
"zero_acc_cfg": CostConfig,
|
||||
"zero_vel_cfg": CostConfig,
|
||||
"zero_jerk_cfg": CostConfig,
|
||||
"link_pose_cfg": PoseCostConfig,
|
||||
}
|
||||
new_k.update(base_k)
|
||||
return new_k
|
||||
@@ -166,10 +169,17 @@ class ArmReacher(ArmBase, ArmReacherConfig):
|
||||
self.dist_cost = DistCost(self.cost_cfg.cspace_cfg)
|
||||
if self.cost_cfg.pose_cfg is not None:
|
||||
self.goal_cost = PoseCost(self.cost_cfg.pose_cfg)
|
||||
self._link_pose_costs = {}
|
||||
if self.cost_cfg.link_pose_cfg is None:
|
||||
log_info(
|
||||
"Deprecated: Add link_pose_cfg to your rollout config. Using pose_cfg instead."
|
||||
)
|
||||
self.cost_cfg.link_pose_cfg = self.cost_cfg.pose_cfg
|
||||
self._link_pose_costs = {}
|
||||
|
||||
if self.cost_cfg.link_pose_cfg is not None:
|
||||
for i in self.kinematics.link_names:
|
||||
if i != self.kinematics.ee_link:
|
||||
self._link_pose_costs[i] = PoseCost(self.cost_cfg.pose_cfg)
|
||||
self._link_pose_costs[i] = PoseCost(self.cost_cfg.link_pose_cfg)
|
||||
if self.cost_cfg.straight_line_cfg is not None:
|
||||
self.straight_line_cost = StraightLineCost(self.cost_cfg.straight_line_cfg)
|
||||
if self.cost_cfg.zero_vel_cfg is not None:
|
||||
@@ -192,12 +202,20 @@ class ArmReacher(ArmBase, ArmReacherConfig):
|
||||
self.z_tensor = torch.tensor(
|
||||
0, device=self.tensor_args.device, dtype=self.tensor_args.dtype
|
||||
)
|
||||
self._link_pose_convergence = {}
|
||||
|
||||
if self.convergence_cfg.pose_cfg is not None:
|
||||
self.pose_convergence = PoseCost(self.convergence_cfg.pose_cfg)
|
||||
self._link_pose_convergence = {}
|
||||
if self.convergence_cfg.link_pose_cfg is None:
|
||||
log_warn(
|
||||
"Deprecated: Add link_pose_cfg to your rollout config. Using pose_cfg instead."
|
||||
)
|
||||
self.convergence_cfg.link_pose_cfg = self.convergence_cfg.pose_cfg
|
||||
|
||||
if self.convergence_cfg.link_pose_cfg is not None:
|
||||
for i in self.kinematics.link_names:
|
||||
if i != self.kinematics.ee_link:
|
||||
self._link_pose_convergence[i] = PoseCost(self.convergence_cfg.pose_cfg)
|
||||
self._link_pose_convergence[i] = PoseCost(self.convergence_cfg.link_pose_cfg)
|
||||
if self.convergence_cfg.cspace_cfg is not None:
|
||||
self.cspace_convergence = DistCost(self.convergence_cfg.cspace_cfg)
|
||||
|
||||
@@ -307,6 +325,7 @@ class ArmReacher(ArmBase, ArmReacherConfig):
|
||||
# print(z_vel.shape)
|
||||
cost_list.append(z_vel)
|
||||
cost = cat_sum(cost_list)
|
||||
# print(cost[:].T)
|
||||
return cost
|
||||
|
||||
def convergence_fn(
|
||||
|
||||
@@ -691,8 +691,6 @@ class PoseCost(CostBase, PoseCostConfig):
|
||||
)
|
||||
|
||||
cost = distance
|
||||
|
||||
# print(cost.shape)
|
||||
return cost
|
||||
|
||||
def forward_pose(
|
||||
|
||||
@@ -340,9 +340,9 @@ class CliqueTensorStepKernel(torch.autograd.Function):
|
||||
grad_out_a,
|
||||
grad_out_j,
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
)
|
||||
return (
|
||||
u_grad,
|
||||
@@ -412,9 +412,9 @@ class CliqueTensorStepIdxKernel(torch.autograd.Function):
|
||||
grad_out_a,
|
||||
grad_out_j,
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
)
|
||||
return (
|
||||
u_grad,
|
||||
@@ -483,9 +483,9 @@ class CliqueTensorStepCentralDifferenceKernel(torch.autograd.Function):
|
||||
grad_out_a.contiguous(),
|
||||
grad_out_j.contiguous(),
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
0,
|
||||
)
|
||||
return (
|
||||
@@ -557,9 +557,9 @@ class CliqueTensorStepIdxCentralDifferenceKernel(torch.autograd.Function):
|
||||
grad_out_a.contiguous(),
|
||||
grad_out_j.contiguous(),
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
0,
|
||||
)
|
||||
return (
|
||||
@@ -626,9 +626,9 @@ class CliqueTensorStepCoalesceKernel(torch.autograd.Function):
|
||||
grad_out_a.transpose(-1, -2).contiguous(),
|
||||
grad_out_j.transpose(-1, -2).contiguous(),
|
||||
traj_dt,
|
||||
out_grad_position.shape[0],
|
||||
out_grad_position.shape[1],
|
||||
out_grad_position.shape[2],
|
||||
grad_out_p.shape[0],
|
||||
grad_out_p.shape[1],
|
||||
grad_out_p.shape[2],
|
||||
)
|
||||
return (
|
||||
u_grad.transpose(-1, -2).contiguous(),
|
||||
@@ -890,16 +890,16 @@ def interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType):
|
||||
return mat
|
||||
|
||||
|
||||
def action_interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType):
|
||||
def action_interpolate_kernel(h, int_steps, tensor_args: TensorDeviceType, offset: int = 4):
|
||||
mat = torch.zeros(
|
||||
((h - 1) * (int_steps), h), device=tensor_args.device, dtype=tensor_args.dtype
|
||||
)
|
||||
delta = torch.arange(0, int_steps - 2, device=tensor_args.device, dtype=tensor_args.dtype) / (
|
||||
int_steps - 1.0 - 2
|
||||
)
|
||||
delta = torch.arange(
|
||||
0, int_steps - offset + 1, device=tensor_args.device, dtype=tensor_args.dtype
|
||||
) / (int_steps - offset)
|
||||
for i in range(h - 1):
|
||||
mat[i * int_steps : i * (int_steps) + int_steps - 1 - 2, i] = delta.flip(0)[1:]
|
||||
mat[i * int_steps : i * (int_steps) + int_steps - 1 - 2, i + 1] = delta[1:]
|
||||
mat[-3:, 1] = 1.0
|
||||
mat[i * int_steps : i * (int_steps) + int_steps - offset, i] = delta.flip(0)[1:]
|
||||
mat[i * int_steps : i * (int_steps) + int_steps - offset, i + 1] = delta[1:]
|
||||
mat[-offset:, 1] = 1.0
|
||||
|
||||
return mat
|
||||
|
||||
@@ -176,6 +176,7 @@ class KinematicModel(KinematicModelConfig):
|
||||
self._use_clique_kernel = True
|
||||
self.d_state = 4 * self.n_dofs # + 1
|
||||
self.d_action = self.n_dofs
|
||||
self.d_dof = self.n_dofs
|
||||
|
||||
# Variables for enforcing joint limits
|
||||
self.joint_names = self.robot_model.joint_names
|
||||
@@ -190,7 +191,7 @@ class KinematicModel(KinematicModelConfig):
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
),
|
||||
dof=int(self.d_state / 3),
|
||||
dof=int(self.d_dof),
|
||||
)
|
||||
self.Z = torch.tensor([0.0], device=self.tensor_args.device, dtype=self.tensor_args.dtype)
|
||||
|
||||
@@ -204,10 +205,18 @@ class KinematicModel(KinematicModelConfig):
|
||||
# self._cmd_step_fn = TensorStepAcceleration(self.tensor_args, self.traj_dt)
|
||||
|
||||
self._rollout_step_fn = TensorStepAccelerationKernel(
|
||||
self.tensor_args, self.traj_dt, self.n_dofs
|
||||
self.tensor_args,
|
||||
self.traj_dt,
|
||||
self.n_dofs,
|
||||
self.batch_size,
|
||||
self.horizon,
|
||||
)
|
||||
self._cmd_step_fn = TensorStepAccelerationKernel(
|
||||
self.tensor_args, self.traj_dt, self.n_dofs
|
||||
self.tensor_args,
|
||||
self.traj_dt,
|
||||
self.n_dofs,
|
||||
self.batch_size,
|
||||
self.horizon,
|
||||
)
|
||||
elif self.control_space == StateType.VELOCITY:
|
||||
raise NotImplementedError()
|
||||
@@ -215,8 +224,12 @@ class KinematicModel(KinematicModelConfig):
|
||||
raise NotImplementedError()
|
||||
elif self.control_space == StateType.POSITION:
|
||||
if self.teleport_mode:
|
||||
self._rollout_step_fn = TensorStepPositionTeleport(self.tensor_args)
|
||||
self._cmd_step_fn = TensorStepPositionTeleport(self.tensor_args)
|
||||
self._rollout_step_fn = TensorStepPositionTeleport(
|
||||
self.tensor_args, self.batch_size, self.horizon
|
||||
)
|
||||
self._cmd_step_fn = TensorStepPositionTeleport(
|
||||
self.tensor_args, self.batch_size, self.horizon
|
||||
)
|
||||
else:
|
||||
if self._use_clique:
|
||||
if self._use_clique_kernel:
|
||||
@@ -237,6 +250,8 @@ class KinematicModel(KinematicModelConfig):
|
||||
filter_velocity=False,
|
||||
filter_acceleration=False,
|
||||
filter_jerk=False,
|
||||
batch_size=self.batch_size,
|
||||
horizon=self.horizon,
|
||||
)
|
||||
self._cmd_step_fn = TensorStepPositionCliqueKernel(
|
||||
self.tensor_args,
|
||||
@@ -246,17 +261,36 @@ class KinematicModel(KinematicModelConfig):
|
||||
filter_velocity=False,
|
||||
filter_acceleration=self.filter_robot_command,
|
||||
filter_jerk=self.filter_robot_command,
|
||||
batch_size=self.batch_size,
|
||||
horizon=self.horizon,
|
||||
)
|
||||
|
||||
else:
|
||||
self._rollout_step_fn = TensorStepPositionClique(
|
||||
self.tensor_args, self.traj_dt
|
||||
self.tensor_args,
|
||||
self.traj_dt,
|
||||
batch_size=self.batch_size,
|
||||
horizon=self.horizon,
|
||||
)
|
||||
self._cmd_step_fn = TensorStepPositionClique(
|
||||
self.tensor_args,
|
||||
self.traj_dt,
|
||||
batch_size=self.batch_size,
|
||||
horizon=self.horizon,
|
||||
)
|
||||
self._cmd_step_fn = TensorStepPositionClique(self.tensor_args, self.traj_dt)
|
||||
else:
|
||||
self._rollout_step_fn = TensorStepPosition(self.tensor_args, self.traj_dt)
|
||||
self._cmd_step_fn = TensorStepPosition(self.tensor_args, self.traj_dt)
|
||||
|
||||
self._rollout_step_fn = TensorStepPosition(
|
||||
self.tensor_args,
|
||||
self.traj_dt,
|
||||
batch_size=self.batch_size,
|
||||
horizon=self.horizon,
|
||||
)
|
||||
self._cmd_step_fn = TensorStepPosition(
|
||||
self.tensor_args,
|
||||
self.traj_dt,
|
||||
batch_size=self.batch_size,
|
||||
horizon=self.horizon,
|
||||
)
|
||||
self.update_batch_size(self.batch_size)
|
||||
|
||||
self.state_filter = JointStateFilter(self.state_filter_cfg)
|
||||
@@ -542,10 +576,10 @@ class KinematicModel(KinematicModelConfig):
|
||||
# output should be d_action * horizon
|
||||
if self.control_space == StateType.POSITION:
|
||||
# use joint limits:
|
||||
return self.retract_config.unsqueeze(0).repeat(self.horizon, 1)
|
||||
return self.retract_config.unsqueeze(0).repeat(self.action_horizon, 1)
|
||||
if self.control_space == StateType.VELOCITY or self.control_space == StateType.ACCELERATION:
|
||||
# use joint limits:
|
||||
return self.retract_config.unsqueeze(0).repeat(self.horizon, 1) * 0.0
|
||||
return self.retract_config.unsqueeze(0).repeat(self.action_horizon, 1) * 0.0
|
||||
|
||||
@property
|
||||
def retract_config(self):
|
||||
@@ -567,6 +601,10 @@ class KinematicModel(KinematicModelConfig):
|
||||
def max_jerk(self):
|
||||
return self.get_state_bounds().jerk[1, 0].item()
|
||||
|
||||
@property
|
||||
def action_horizon(self):
|
||||
return self._rollout_step_fn.action_horizon
|
||||
|
||||
def get_state_bounds(self):
|
||||
joint_limits = self.robot_model.get_joint_limits()
|
||||
return joint_limits
|
||||
|
||||
@@ -49,12 +49,16 @@ class TensorStepType(Enum):
|
||||
|
||||
|
||||
class TensorStepBase:
|
||||
def __init__(self, tensor_args: TensorDeviceType) -> None:
|
||||
def __init__(
|
||||
self, tensor_args: TensorDeviceType, batch_size: int = 1, horizon: int = 1
|
||||
) -> None:
|
||||
self.batch_size = -1
|
||||
self.horizon = -1
|
||||
self.tensor_args = tensor_args
|
||||
self._diag_dt = None
|
||||
self._inv_dt_h = None
|
||||
self.action_horizon = horizon
|
||||
self.update_batch_size(batch_size, horizon)
|
||||
|
||||
def update_dt(self, dt: float):
|
||||
self._dt_h[:] = dt
|
||||
@@ -83,8 +87,14 @@ class TensorStepBase:
|
||||
|
||||
|
||||
class TensorStepAcceleration(TensorStepBase):
|
||||
def __init__(self, tensor_args: TensorDeviceType, dt_h: torch.Tensor) -> None:
|
||||
super().__init__(tensor_args)
|
||||
def __init__(
|
||||
self,
|
||||
tensor_args: TensorDeviceType,
|
||||
dt_h: torch.Tensor,
|
||||
batch_size: int = 1,
|
||||
horizon: int = 1,
|
||||
) -> None:
|
||||
super().__init__(tensor_args, batch_size=batch_size, horizon=horizon)
|
||||
self._dt_h = dt_h
|
||||
self._diag_dt_h = torch.diag(self._dt_h)
|
||||
self._integrate_matrix_pos = None
|
||||
@@ -138,8 +148,13 @@ class TensorStepAcceleration(TensorStepBase):
|
||||
|
||||
|
||||
class TensorStepPositionTeleport(TensorStepBase):
|
||||
def __init__(self, tensor_args: TensorDeviceType) -> None:
|
||||
super().__init__(tensor_args)
|
||||
def __init__(
|
||||
self,
|
||||
tensor_args: TensorDeviceType,
|
||||
batch_size: int = 1,
|
||||
horizon: int = 1,
|
||||
) -> None:
|
||||
super().__init__(tensor_args, batch_size=batch_size, horizon=horizon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -153,8 +168,14 @@ class TensorStepPositionTeleport(TensorStepBase):
|
||||
|
||||
|
||||
class TensorStepPosition(TensorStepBase):
|
||||
def __init__(self, tensor_args: TensorDeviceType, dt_h: torch.Tensor) -> None:
|
||||
super().__init__(tensor_args)
|
||||
def __init__(
|
||||
self,
|
||||
tensor_args: TensorDeviceType,
|
||||
dt_h: torch.Tensor,
|
||||
batch_size: int = 1,
|
||||
horizon: int = 1,
|
||||
) -> None:
|
||||
super().__init__(tensor_args, batch_size=batch_size, horizon=horizon)
|
||||
|
||||
self._dt_h = dt_h
|
||||
# self._diag_dt_h = torch.diag(1 / self._dt_h)
|
||||
@@ -185,7 +206,6 @@ class TensorStepPosition(TensorStepBase):
|
||||
)
|
||||
self._fd_matrix = torch.diag(1.0 / self._dt_h) @ self._fd_matrix
|
||||
|
||||
# self._fd_matrix = self._diag_dt_h @ self._fd_matrix
|
||||
return super().update_batch_size(batch_size, horizon)
|
||||
|
||||
def forward(
|
||||
@@ -205,8 +225,14 @@ class TensorStepPosition(TensorStepBase):
|
||||
|
||||
|
||||
class TensorStepPositionClique(TensorStepBase):
|
||||
def __init__(self, tensor_args: TensorDeviceType, dt_h: torch.Tensor) -> None:
|
||||
super().__init__(tensor_args)
|
||||
def __init__(
|
||||
self,
|
||||
tensor_args: TensorDeviceType,
|
||||
dt_h: torch.Tensor,
|
||||
batch_size: int = 1,
|
||||
horizon: int = 1,
|
||||
) -> None:
|
||||
super().__init__(tensor_args, batch_size=batch_size, horizon=horizon)
|
||||
|
||||
self._dt_h = dt_h
|
||||
self._inv_dt_h = 1.0 / dt_h
|
||||
@@ -281,12 +307,20 @@ class TensorStepPositionClique(TensorStepBase):
|
||||
|
||||
|
||||
class TensorStepAccelerationKernel(TensorStepBase):
|
||||
def __init__(self, tensor_args: TensorDeviceType, dt_h: torch.Tensor, dof: int) -> None:
|
||||
super().__init__(tensor_args)
|
||||
def __init__(
|
||||
self,
|
||||
tensor_args: TensorDeviceType,
|
||||
dt_h: torch.Tensor,
|
||||
dof: int,
|
||||
batch_size: int = 1,
|
||||
horizon: int = 1,
|
||||
) -> None:
|
||||
self.dof = dof
|
||||
|
||||
super().__init__(tensor_args, batch_size=batch_size, horizon=horizon)
|
||||
|
||||
self._dt_h = dt_h
|
||||
self._u_grad = None
|
||||
self.dof = dof
|
||||
|
||||
def update_batch_size(
|
||||
self,
|
||||
@@ -363,13 +397,15 @@ class TensorStepPositionCliqueKernel(TensorStepBase):
|
||||
filter_velocity: bool = False,
|
||||
filter_acceleration: bool = False,
|
||||
filter_jerk: bool = False,
|
||||
batch_size: int = 1,
|
||||
horizon: int = 1,
|
||||
) -> None:
|
||||
super().__init__(tensor_args)
|
||||
self.dof = dof
|
||||
self._fd_mode = finite_difference_mode
|
||||
super().__init__(tensor_args, batch_size=batch_size, horizon=horizon)
|
||||
self._dt_h = dt_h
|
||||
self._inv_dt_h = 1.0 / dt_h
|
||||
self._u_grad = None
|
||||
self.dof = dof
|
||||
self._fd_mode = finite_difference_mode
|
||||
self._filter_velocity = filter_velocity
|
||||
self._filter_acceleration = filter_acceleration
|
||||
self._filter_jerk = filter_jerk
|
||||
@@ -381,10 +417,6 @@ class TensorStepPositionCliqueKernel(TensorStepBase):
|
||||
weights = kernel
|
||||
self._sma_kernel = weights
|
||||
|
||||
# self._sma = torch.nn.AvgPool1d(kernel_size=5, stride=2, padding=1).to(
|
||||
# device=self.tensor_args.device
|
||||
# )
|
||||
|
||||
def update_batch_size(
|
||||
self,
|
||||
batch_size: Optional[int] = None,
|
||||
@@ -392,8 +424,11 @@ class TensorStepPositionCliqueKernel(TensorStepBase):
|
||||
force_update: bool = False,
|
||||
) -> None:
|
||||
if batch_size != self.batch_size or horizon != self.horizon:
|
||||
self.action_horizon = horizon
|
||||
if self._fd_mode == 0:
|
||||
self.action_horizon = horizon - 4
|
||||
self._u_grad = torch.zeros(
|
||||
(batch_size, horizon, self.dof),
|
||||
(batch_size, self.action_horizon, self.dof),
|
||||
device=self.tensor_args.device,
|
||||
dtype=self.tensor_args.dtype,
|
||||
)
|
||||
|
||||
@@ -222,6 +222,21 @@ class Goal(Sequence):
|
||||
links_goal_pose=links_goal_pose,
|
||||
)
|
||||
|
||||
def clone(self):
|
||||
return Goal(
|
||||
goal_state=self.goal_state,
|
||||
goal_pose=self.goal_pose,
|
||||
current_state=self.current_state,
|
||||
retract_state=self.retract_state,
|
||||
batch_pose_idx=self.batch_pose_idx,
|
||||
batch_world_idx=self.batch_world_idx,
|
||||
batch_enable_idx=self.batch_enable_idx,
|
||||
batch_current_state_idx=self.batch_current_state_idx,
|
||||
batch_retract_state_idx=self.batch_retract_state_idx,
|
||||
batch_goal_state_idx=self.batch_goal_state_idx,
|
||||
links_goal_pose=self.links_goal_pose,
|
||||
)
|
||||
|
||||
def _tensor_repeat_seeds(self, tensor, num_seeds):
|
||||
return tensor_repeat_seeds(tensor, num_seeds)
|
||||
|
||||
@@ -498,6 +513,10 @@ class RolloutBase:
|
||||
def horizon(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def action_horizon(self) -> int:
|
||||
return self.horizon
|
||||
|
||||
def update_start_state(self, start_state: torch.Tensor):
|
||||
if self.start_state is None:
|
||||
self.start_state = start_state
|
||||
|
||||
Reference in New Issue
Block a user