constrained planning, robot segmentation
src/curobo/wrap/model/robot_segmenter.py (new file, 204 lines)
@@ -0,0 +1,204 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#

# Standard Library
from typing import Dict, Optional, Tuple, Union

# Third Party
import torch
from torch.profiler import record_function

# CuRobo
from curobo.geom.cv import (
    get_projection_rays,
    project_depth_using_rays,
)
from curobo.geom.types import PointCloud
from curobo.types.base import TensorDeviceType
from curobo.types.camera import CameraObservation
from curobo.types.math import Pose
from curobo.types.robot import RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_error
from curobo.util_file import get_robot_configs_path, join_path, load_yaml
from curobo.wrap.model.robot_world import RobotWorld, RobotWorldConfig


class RobotSegmenter:
    def __init__(
        self,
        robot_world: RobotWorld,
        distance_threshold: float = 0.05,
        use_cuda_graph: bool = True,
    ):
        self._robot_world = robot_world
        self._projection_rays = None
        self.ready = False
        self._out_points_buffer = None
        self._out_gp = None
        self._out_gq = None
        self._out_gpt = None
        self._cu_graph = None
        self._use_cuda_graph = use_cuda_graph
        self.tensor_args = robot_world.tensor_args
        self.distance_threshold = distance_threshold

    @staticmethod
    def from_robot_file(
        robot_file: Union[str, Dict],
        collision_sphere_buffer: Optional[float],
        distance_threshold: float = 0.05,
        use_cuda_graph: bool = True,
        tensor_args: TensorDeviceType = TensorDeviceType(),
    ):
        robot_dict = load_yaml(join_path(get_robot_configs_path(), robot_file))["robot_cfg"]
        if collision_sphere_buffer is not None:
            robot_dict["kinematics"]["collision_sphere_buffer"] = collision_sphere_buffer

        robot_cfg = RobotConfig.from_dict(robot_dict, tensor_args=tensor_args)

        config = RobotWorldConfig.load_from_config(
            robot_cfg,
            None,
            collision_activation_distance=0.0,
            tensor_args=tensor_args,
        )
        robot_world = RobotWorld(config)

        return RobotSegmenter(
            robot_world, distance_threshold=distance_threshold, use_cuda_graph=use_cuda_graph
        )

    def get_pointcloud_from_depth(self, camera_obs: CameraObservation):
        if self._projection_rays is None:
            self.update_camera_projection(camera_obs)
        depth_image = camera_obs.depth_image
        if len(depth_image.shape) == 2:
            depth_image = depth_image.unsqueeze(0)
        points = project_depth_using_rays(depth_image, self._projection_rays)

        return points

    def update_camera_projection(self, camera_obs: CameraObservation):
        intrinsics = camera_obs.intrinsics
        if len(intrinsics.shape) == 2:
            intrinsics = intrinsics.unsqueeze(0)
        project_rays = get_projection_rays(
            camera_obs.depth_image.shape[-2], camera_obs.depth_image.shape[-1], intrinsics
        )

        if self._projection_rays is None:
            self._projection_rays = project_rays

        self._projection_rays.copy_(project_rays)
        self.ready = True

    @record_function("robot_segmenter/get_robot_mask")
    def get_robot_mask(
        self,
        camera_obs: CameraObservation,
        joint_state: JointState,
    ):
        """
        Assumes 1 robot and batch of depth images, batch of poses
        """
        if len(camera_obs.depth_image.shape) != 3:
            log_error("Send depth image as (batch, height, width)")

        active_js = self._robot_world.get_active_js(joint_state)

        mask, filtered_image = self.get_robot_mask_from_active_js(camera_obs, active_js)

        return mask, filtered_image

    def get_robot_mask_from_active_js(
        self, camera_obs: CameraObservation, active_joint_state: JointState
    ):
        q = active_joint_state.position
        mask, filtered_image = self._call_op(camera_obs, q)

        return mask, filtered_image

    def _create_cg_graph(self, cam_obs, q):
        self._cu_cam_obs = cam_obs.clone()
        self._cu_q = q.clone()
        s = torch.cuda.Stream(device=self.tensor_args.device)
        s.wait_stream(torch.cuda.current_stream(device=self.tensor_args.device))

        with torch.cuda.stream(s):
            for _ in range(3):
                self._cu_out, self._cu_filtered_out = self._mask_op(self._cu_cam_obs, self._cu_q)
        torch.cuda.current_stream(device=self.tensor_args.device).wait_stream(s)

        self._cu_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._cu_graph, stream=s):
            self._cu_out, self._cu_filtered_out = self._mask_op(
                self._cu_cam_obs,
                self._cu_q,
            )

    def _call_op(self, cam_obs, q):
        if self._use_cuda_graph:
            if self._cu_graph is None:
                self._create_cg_graph(cam_obs, q)
            self._cu_cam_obs.copy_(cam_obs)
            self._cu_q.copy_(q)
            self._cu_graph.replay()
            return self._cu_out.clone(), self._cu_filtered_out.clone()
        return self._mask_op(cam_obs, q)

    @record_function("robot_segmenter/_mask_op")
    def _mask_op(self, camera_obs, q):
        if len(q.shape) == 1:
            q = q.unsqueeze(0)
        points = self.get_pointcloud_from_depth(camera_obs)
        camera_to_robot = camera_obs.pose

        if self._out_points_buffer is None:
            self._out_points_buffer = points.clone()
        if self._out_gpt is None:
            self._out_gpt = torch.zeros((points.shape[0], points.shape[1], 3), device=points.device)
        if self._out_gp is None:
            self._out_gp = torch.zeros((camera_to_robot.position.shape[0], 3), device=points.device)
        if self._out_gq is None:
            self._out_gq = torch.zeros(
                (camera_to_robot.quaternion.shape[0], 4), device=points.device
            )

        points_in_robot_frame = camera_to_robot.batch_transform_points(
            points,
            out_buffer=self._out_points_buffer,
            gp_out=self._out_gp,
            gq_out=self._out_gq,
            gpt_out=self._out_gpt,
        )

        out_points = points_in_robot_frame

        dist = self._robot_world.get_point_robot_distance(out_points, q)

        mask, filtered_image = mask_image(camera_obs.depth_image, dist, self.distance_threshold)

        return mask, filtered_image


@torch.jit.script
def mask_image(
    image: torch.Tensor, distance: torch.Tensor, distance_threshold: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    distance = distance.view(
        image.shape[0],
        image.shape[1],
        image.shape[2],
    )
    mask = torch.logical_and((image > 0.0), (distance > -distance_threshold))
    filtered_image = torch.where(mask, 0, image)
    return mask, filtered_image
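
Example usage of the new segmenter (a sketch, not part of the diff): the robot config name, camera values, and the CameraObservation / JointState constructor keywords below are illustrative assumptions and should be checked against the cuRobo version in use.

# Sketch only. Assumes a Franka config ("franka.yml"), a 480x640 depth image in meters,
# and that CameraObservation / JointState accept the keyword arguments shown.
import torch

from curobo.types.base import TensorDeviceType
from curobo.types.camera import CameraObservation
from curobo.types.math import Pose
from curobo.types.state import JointState
from curobo.wrap.model.robot_segmenter import RobotSegmenter

tensor_args = TensorDeviceType()
segmenter = RobotSegmenter.from_robot_file(
    "franka.yml", collision_sphere_buffer=0.02, distance_threshold=0.05
)

depth = torch.zeros((1, 480, 640), device=tensor_args.device)  # batch of depth images
intrinsics = torch.eye(3, device=tensor_args.device)  # placeholder camera matrix
camera_pose = Pose.from_list([0.5, 0.0, 0.5, 1.0, 0.0, 0.0, 0.0], tensor_args)  # camera in robot base frame

camera_obs = CameraObservation(depth_image=depth, intrinsics=intrinsics, pose=camera_pose)
joint_state = JointState.from_position(
    torch.zeros((1, 7), device=tensor_args.device),
    joint_names=["panda_joint" + str(i) for i in range(1, 8)],
)

# mask is True on pixels that belong to the robot; filtered_depth has those pixels zeroed
mask, filtered_depth = segmenter.get_robot_mask(camera_obs, joint_state)
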
@@ -34,6 +34,7 @@ from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
from curobo.types.robot import RobotConfig
from curobo.types.state import JointState
from curobo.util.logger import log_error
from curobo.util.sample_lib import HaltonGenerator
from curobo.util.warp import init_warp
from curobo.util_file import get_robot_configs_path, get_world_configs_path, join_path, load_yaml
@@ -60,8 +61,8 @@ class RobotWorldConfig:
        world_model: Union[None, str, Dict, WorldConfig, List[WorldConfig], List[str]] = None,
        tensor_args: TensorDeviceType = TensorDeviceType(),
        n_envs: int = 1,
        n_meshes: int = 10,
        n_cuboids: int = 10,
        n_meshes: int = 50,
        n_cuboids: int = 50,
        collision_activation_distance: float = 0.2,
        self_collision_activation_distance: float = 0.0,
        max_collision_distance: float = 1.0,
@@ -74,6 +75,8 @@ class RobotWorldConfig:
        if isinstance(robot_config, str):
            robot_config = load_yaml(join_path(get_robot_configs_path(), robot_config))["robot_cfg"]
        if isinstance(robot_config, Dict):
            if "robot_cfg" in robot_config:
                robot_config = robot_config["robot_cfg"]
            robot_config = RobotConfig.from_dict(robot_config, tensor_args)
        kinematics = CudaRobotModel(robot_config.kinematics)

@@ -178,8 +181,11 @@ class RobotWorld(RobotWorldConfig):
    def __init__(self, config: RobotWorldConfig) -> None:
        RobotWorldConfig.__init__(self, **vars(config))
        self._batch_pose_idx = None
        self._camera_projection_rays = None

    def get_kinematics(self, q: torch.Tensor) -> CudaRobotModelState:
        if len(q.shape) == 1:
            log_error("q should be of shape [b, dof]")
        state = self.kinematics.get_state(q)
        return state

@@ -344,6 +350,37 @@ class RobotWorld(RobotWorldConfig):
        distance = self.pose_cost.forward_pose(x_des, x_current, self._batch_pose_idx)
        return distance

    def get_point_robot_distance(self, points: torch.Tensor, q: torch.Tensor):
        """Compute distance from the robot at q joint configuration to points (e.g., pointcloud)

        Args:
            points: [b, n, 3]
            q: [1, dof]

        Returns:
            distance: [b, n] Positive is in collision with robot.
            NOTE: This currently does not support batched robot but can be done easily.
        """
        if len(q.shape) == 1:
            log_error("q should be of shape [b, dof]")
        kin_state = self.get_kinematics(q)
        b, n = None, None
        if len(points.shape) == 3:
            b, n, _ = points.shape
            points = points.view(b * n, 3)

        pt_distance = point_robot_distance(kin_state.link_spheres_tensor, points)

        if b is not None:
            pt_distance = pt_distance.view(b, n)

        return pt_distance

    def get_active_js(self, full_js: JointState):
        active_jnames = self.kinematics.joint_names
        out_js = full_js.get_ordered_joint_state(active_jnames)
        return out_js


@torch.jit.script
def sum_mask(d1, d2, d3):
@@ -357,3 +394,15 @@ def mask(d1, d2, d3):
    d_total = d1 + d2 + d3
    d_mask = d_total == 0.0
    return d_mask


@torch.jit.script
def point_robot_distance(link_spheres_tensor, points):
    robot_spheres = link_spheres_tensor.view(1, -1, 4).contiguous()
    robot_radius = robot_spheres[:, :, 3]
    points = points.unsqueeze(1)
    sph_distance = (
        torch.linalg.norm(points - robot_spheres[:, :, :3], dim=-1) - robot_radius
    )  # b, n_spheres
    pt_distance = torch.max(-1 * sph_distance, dim=-1)[0]
    return pt_distance
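
point_robot_distance treats the robot as its set of collision spheres: for each query point it keeps the largest penetration over all spheres, so the result is positive inside the robot and negative (clearance) outside. A self-contained sketch of the same computation on a toy two-sphere "robot" (all values are made up for illustration):

import torch

# toy robot: two spheres stored as (x, y, z, radius), mirroring link_spheres_tensor
link_spheres = torch.tensor([[0.0, 0.0, 0.0, 0.1], [0.0, 0.0, 0.3, 0.1]])
points = torch.tensor([[0.0, 0.0, 0.05], [0.0, 0.0, 1.0]])  # one point inside, one far away

spheres = link_spheres.view(1, -1, 4)
# distance of each point to each sphere surface (negative means inside that sphere)
dist_to_surface = torch.linalg.norm(points.unsqueeze(1) - spheres[:, :, :3], dim=-1) - spheres[:, :, 3]
signed = torch.max(-dist_to_surface, dim=-1)[0]
print(signed)  # ~[0.05, -0.60]: positive => inside the robot, negative => clearance
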
@@ -52,7 +52,8 @@ def smooth_cost(abs_acc, abs_jerk, opt_dt):
    # jerk = torch.max(torch.max(abs_jerk, dim=-1)[0], dim=-1)[0]
    jerk = torch.mean(torch.max(abs_jerk, dim=-1)[0], dim=-1)
    mean_acc = torch.mean(torch.max(abs_acc, dim=-1)[0], dim=-1)  # [0]
    a = (jerk * 0.001) + opt_dt + (mean_acc * 0.01)
    a = (jerk * 0.001) + 5.0 * opt_dt + (mean_acc * 0.01)

    return a
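
The functional change here is the 5.0 weight on opt_dt, so the optimized trajectory duration now dominates the smoothness score used to rank seeds. For example, with jerk = 10.0, opt_dt = 0.5 and mean_acc = 2.0, the score goes from 0.01 + 0.5 + 0.02 = 0.53 to 0.01 + 2.5 + 0.02 = 2.53, i.e. a difference in optimized dt is now weighted five times more heavily relative to jerk and acceleration.
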
@@ -26,6 +26,7 @@ from curobo.opt.newton.newton_base import NewtonOptBase, NewtonOptConfig
from curobo.opt.particle.parallel_es import ParallelES, ParallelESConfig
from curobo.opt.particle.parallel_mppi import ParallelMPPI, ParallelMPPIConfig
from curobo.rollout.arm_reacher import ArmReacher, ArmReacherConfig
from curobo.rollout.cost.pose_cost import PoseCostMetric
from curobo.rollout.rollout_base import Goal, RolloutBase, RolloutMetrics
from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
@@ -56,6 +57,7 @@ class IKSolverConfig:
    world_coll_checker: Optional[WorldCollision] = None
    sample_rejection_ratio: int = 50
    tensor_args: TensorDeviceType = TensorDeviceType()
    use_cuda_graph: bool = True

    @staticmethod
    @profiler.record_function("ik_solver/load_from_robot_config")
@@ -72,12 +74,12 @@ class IKSolverConfig:
        base_cfg_file: str = "base_cfg.yml",
        particle_file: str = "particle_ik.yml",
        gradient_file: str = "gradient_ik.yml",
        use_cuda_graph: Optional[bool] = None,
        use_cuda_graph: bool = True,
        self_collision_check: bool = True,
        self_collision_opt: bool = True,
        grad_iters: Optional[int] = None,
        use_particle_opt: bool = True,
        collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.PRIMITIVE,
        collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.MESH,
        sync_cuda_time: Optional[bool] = None,
        use_gradient_descent: bool = False,
        collision_cache: Optional[Dict[str, int]] = None,
@@ -90,6 +92,7 @@ class IKSolverConfig:
        regularization: bool = True,
        collision_activation_distance: Optional[float] = None,
        high_precision: bool = False,
        project_pose_to_goal_frame: bool = True,
    ):
        if position_threshold <= 0.001:
            high_precision = True
@@ -116,6 +119,9 @@ class IKSolverConfig:
            base_config_data["convergence"]["cspace_cfg"]["weight"] = 0.0
            config_data["cost"]["bound_cfg"]["null_space_weight"] = 0.0
            grad_config_data["cost"]["bound_cfg"]["null_space_weight"] = 0.0

        if isinstance(robot_cfg, str):
            robot_cfg = load_yaml(join_path(get_robot_configs_path(), robot_cfg))["robot_cfg"]
        if ee_link_name is not None:
            if isinstance(robot_cfg, RobotConfig):
                raise NotImplementedError("ee link cannot be changed after creating RobotConfig")
@@ -123,8 +129,6 @@ class IKSolverConfig:
                robot_cfg.kinematics.ee_link = ee_link_name
            else:
                robot_cfg["kinematics"]["ee_link"] = ee_link_name
        if isinstance(robot_cfg, str):
            robot_cfg = load_yaml(join_path(get_robot_configs_path(), robot_cfg))["robot_cfg"]
        if isinstance(robot_cfg, dict):
            robot_cfg = RobotConfig.from_dict(robot_cfg, tensor_args)

@@ -160,8 +164,8 @@ class IKSolverConfig:
            grad_config_data["cost"]["self_collision_cfg"]["weight"] = 0.0
        if grad_iters is not None:
            grad_config_data["lbfgs"]["n_iters"] = grad_iters
        config_data["mppi"]["n_envs"] = 1
        grad_config_data["lbfgs"]["n_envs"] = 1
        config_data["mppi"]["n_problems"] = 1
        grad_config_data["lbfgs"]["n_problems"] = 1
        grad_cfg = ArmReacherConfig.from_dict(
            robot_cfg,
            grad_config_data["model"],
@@ -241,6 +245,7 @@ class IKSolverConfig:
            world_coll_checker=world_coll_checker,
            rollout_fn=aux_rollout,
            tensor_args=tensor_args,
            use_cuda_graph=use_cuda_graph,
        )
        return ik_cfg

@@ -259,6 +264,7 @@ class IKResult(Sequence):
    error: T_BValue_float
    solve_time: float
    debug_info: Optional[Any] = None
    goalset_index: Optional[torch.Tensor] = None

    def __getitem__(self, idx):
        success = self.success[idx]
@@ -272,6 +278,7 @@ class IKResult(Sequence):
            position_error=self.position_error[idx],
            rotation_error=self.rotation_error[idx],
            debug_info=self.debug_info,
            goalset_index=None if self.goalset_index is None else self.goalset_index[idx],
        )

    def __len__(self):
@@ -317,7 +324,7 @@ class IKSolver(IKSolverConfig):
            self.solver.rollout_fn.retract_state.unsqueeze(0)
        )
        self.dof = self.solver.safety_rollout.d_action
        self._col = torch.arange(0, 1, device=self.tensor_args.device, dtype=torch.long)
        self._col = None  # torch.arange(0, 1, device=self.tensor_args.device, dtype=torch.long)

        # self.fixed_seeds = self.solver.safety_rollout.sample_random_actions(100 * 200)
        # create random seeder:
@@ -355,10 +362,14 @@ class IKSolver(IKSolverConfig):
            self._goal_buffer,
            self.tensor_args,
        )
        # print("Goal:", self._goal_buffer.links_goal_pose)

        if update_reference:
            self.solver.update_nenvs(self._solve_state.get_ik_batch_size())
            self.reset_cuda_graph()
            self.reset_shape()
            if self.use_cuda_graph and self._col is not None:
                log_error("changing goal type, breaking previous cuda graph.")
                self.reset_cuda_graph()

            self.solver.update_nproblems(self._solve_state.get_ik_batch_size())
            self._goal_buffer.current_state = self.init_state.repeat_seeds(goal_pose.batch)
            self._col = torch.arange(
                0,
@@ -676,6 +687,8 @@ class IKSolver(IKSolverConfig):
        if newton_iters is not None:
            self.solver.newton_optimizer.outer_iters = self.og_newton_iters
        ik_result = self.get_result(num_seeds, result, goal_buffer.goal_pose, return_seeds)
        if ik_result.goalset_index is not None:
            ik_result.goalset_index[ik_result.goalset_index >= goal_pose.n_goalset] = 0

        return ik_result

@@ -684,15 +697,18 @@ class IKSolver(IKSolverConfig):
        self, num_seeds: int, result: WrapResult, goal_pose: Pose, return_seeds: int
    ) -> IKResult:
        success = self.get_success(result.metrics, num_seeds=num_seeds)
        if result.metrics.cost is not None:
            result.metrics.pose_error += result.metrics.cost
        # if result.metrics.cost is not None:
        #     result.metrics.pose_error += result.metrics.cost * 0.0001
        if result.metrics.null_space_error is not None:
            result.metrics.pose_error += result.metrics.null_space_error
        if result.metrics.cspace_error is not None:
            result.metrics.pose_error += result.metrics.cspace_error

        q_sol, success, position_error, rotation_error, total_error = get_result(
        q_sol, success, position_error, rotation_error, total_error, goalset_index = get_result(
            result.metrics.pose_error,
            result.metrics.position_error,
            result.metrics.rotation_error,
            result.metrics.goalset_index,
            success,
            result.action.position,
            self._col,
@@ -717,6 +733,7 @@ class IKSolver(IKSolverConfig):
            solve_time=result.solve_time,
            error=total_error,
            debug_info={"solver": result.debug},
            goalset_index=goalset_index,
        )
        return ik_result

@@ -959,6 +976,10 @@ class IKSolver(IKSolverConfig):
        self.solver.reset_cuda_graph()
        self.rollout_fn.reset_cuda_graph()

    def reset_shape(self):
        self.solver.reset_shape()
        self.rollout_fn.reset_shape()

    def attach_object_to_robot(
        self,
        sphere_radius: float,
@@ -977,6 +998,17 @@ class IKSolver(IKSolverConfig):
    def get_retract_config(self):
        return self.rollout_fn.dynamics_model.retract_config

    def update_pose_cost_metric(
        self,
        metric: PoseCostMetric,
    ):
        rollouts = self.get_all_rollout_instances()
        [
            rollout.update_pose_cost_metric(metric)
            for rollout in rollouts
            if isinstance(rollout, ArmReacher)
        ]
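
This hook is how constrained planning reaches the solver level: a single PoseCostMetric is pushed into every ArmReacher rollout owned by the solver. A hedged sketch of a call site, assuming an already constructed IKSolver named ik_solver; the PoseCostMetric fields shown (hold_partial_pose, hold_vec_weight) and the axis ordering are assumptions to be verified against curobo.rollout.cost.pose_cost.

import torch

from curobo.rollout.cost.pose_cost import PoseCostMetric

# Assumed field names and ordering: weight vector [rx, ry, rz, x, y, z], where a 1.0
# asks the cost to hold that axis of the end-effector pose during the solve.
metric = PoseCostMetric(
    hold_partial_pose=True,
    hold_vec_weight=torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 0.0]),  # hold x and y, leave z free
)
ik_solver.update_pose_cost_metric(metric)  # broadcast to every ArmReacher rollout
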


@torch.jit.script
def get_success(
@@ -1001,6 +1033,7 @@ def get_result(
    pose_error,
    position_error,
    rotation_error,
    goalset_index: Union[torch.Tensor, None],
    success,
    sol_position,
    col,
@@ -1018,4 +1051,6 @@ def get_result(
    position_error = position_error[idx].view(batch_size, return_seeds)
    rotation_error = rotation_error[idx].view(batch_size, return_seeds)
    total_error = position_error + rotation_error
    return q_sol, success, position_error, rotation_error, total_error
    if goalset_index is not None:
        goalset_index = goalset_index[idx].view(batch_size, return_seeds)
    return q_sol, success, position_error, rotation_error, total_error, goalset_index
File diff suppressed because it is too large
@@ -80,7 +80,7 @@ class MpcSolverConfig:

        task_file = "particle_mpc.yml"
        config_data = load_yaml(join_path(get_task_configs_path(), task_file))
        config_data["mppi"]["n_envs"] = 1
        config_data["mppi"]["n_problems"] = 1
        if step_dt is not None:
            config_data["model"]["dt_traj_params"]["base_dt"] = step_dt
        if particle_opt_iters is not None:
@@ -238,7 +238,7 @@ class MpcSolver(MpcSolverConfig):
            self.tensor_args,
        )
        if update_reference:
            self.solver.update_nenvs(self._solve_state.get_batch_size())
            self.solver.update_nproblems(self._solve_state.get_batch_size())
            self.reset()
            self.reset_cuda_graph()
            self._col = torch.arange(

@@ -30,6 +30,7 @@ from curobo.opt.newton.newton_base import NewtonOptBase, NewtonOptConfig
from curobo.opt.particle.parallel_es import ParallelES, ParallelESConfig
from curobo.opt.particle.parallel_mppi import ParallelMPPI, ParallelMPPIConfig
from curobo.rollout.arm_reacher import ArmReacher, ArmReacherConfig
from curobo.rollout.cost.pose_cost import PoseCostMetric
from curobo.rollout.dynamics_model.integration_utils import (
    action_interpolate_kernel,
    interpolate_kernel,
@@ -39,7 +40,7 @@ from curobo.types.base import TensorDeviceType
from curobo.types.robot import JointState, RobotConfig
from curobo.types.tensor import T_BDOF, T_DOF, T_BValue_bool, T_BValue_float
from curobo.util.helpers import list_idx_if_not_none
from curobo.util.logger import log_info, log_warn
from curobo.util.logger import log_error, log_info, log_warn
from curobo.util.trajectory import (
    InterpolateType,
    calculate_dt_no_clamp,
@@ -78,6 +79,7 @@ class TrajOptSolverConfig:
    trim_steps: Optional[List[int]] = None
    store_debug_in_result: bool = False
    optimize_dt: bool = True
    use_cuda_graph: bool = True

    @staticmethod
    @profiler.record_function("trajopt_config/load_from_robot_config")
@@ -98,14 +100,14 @@ class TrajOptSolverConfig:
        interpolation_type: InterpolateType = InterpolateType.LINEAR_CUDA,
        interpolation_steps: int = 10000,
        interpolation_dt: float = 0.01,
        use_cuda_graph: Optional[bool] = None,
        use_cuda_graph: bool = True,
        self_collision_check: bool = False,
        self_collision_opt: bool = True,
        grad_trajopt_iters: Optional[int] = None,
        num_seeds: int = 2,
        seed_ratio: Dict[str, int] = {"linear": 1.0, "bias": 0.0, "start": 0.0, "end": 0.0},
        use_particle_opt: bool = True,
        collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.PRIMITIVE,
        collision_checker_type: Optional[CollisionCheckerType] = CollisionCheckerType.MESH,
        traj_evaluator_config: TrajEvaluatorConfig = TrajEvaluatorConfig(),
        traj_evaluator: Optional[TrajEvaluator] = None,
        minimize_jerk: bool = True,
@@ -128,6 +130,7 @@ class TrajOptSolverConfig:
        state_finite_difference_mode: Optional[str] = None,
        filter_robot_command: bool = False,
        optimize_dt: bool = True,
        project_pose_to_goal_frame: bool = True,
    ):
        # NOTE: Don't have default optimize_dt, instead read from a configuration file.
        # use default values, disable environment collision checking
@@ -163,6 +166,16 @@ class TrajOptSolverConfig:
        if traj_tsteps is None:
            traj_tsteps = grad_config_data["model"]["horizon"]

        base_config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame
        base_config_data["convergence"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame
        config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame
        grad_config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame

        base_config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame
        base_config_data["convergence"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame
        config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame
        grad_config_data["cost"]["pose_cfg"]["project_distance"] = project_pose_to_goal_frame

        config_data["model"]["horizon"] = traj_tsteps
        grad_config_data["model"]["horizon"] = traj_tsteps
        if minimize_jerk is not None:
@@ -212,8 +225,8 @@ class TrajOptSolverConfig:
        if not self_collision_opt:
            config_data["cost"]["self_collision_cfg"]["weight"] = 0.0
            grad_config_data["cost"]["self_collision_cfg"]["weight"] = 0.0
        config_data["mppi"]["n_envs"] = 1
        grad_config_data["lbfgs"]["n_envs"] = 1
        config_data["mppi"]["n_problems"] = 1
        grad_config_data["lbfgs"]["n_problems"] = 1

        if fixed_iters is not None:
            grad_config_data["lbfgs"]["fixed_iters"] = fixed_iters
@@ -256,9 +269,10 @@ class TrajOptSolverConfig:
            world_coll_checker=world_coll_checker,
            tensor_args=tensor_args,
        )

        arm_rollout_mppi = None
        with profiler.record_function("trajopt_config/create_rollouts"):
            arm_rollout_mppi = ArmReacher(cfg)
            if use_particle_opt:
                arm_rollout_mppi = ArmReacher(cfg)
            arm_rollout_grad = ArmReacher(grad_cfg)

            arm_rollout_safety = ArmReacher(safety_cfg)
@@ -266,20 +280,22 @@ class TrajOptSolverConfig:
            aux_rollout = ArmReacher(safety_cfg)
            interpolate_rollout = ArmReacher(safety_cfg)
        if trajopt_dt is not None:
            arm_rollout_mppi.update_traj_dt(dt=trajopt_dt)
            if arm_rollout_mppi is not None:
                arm_rollout_mppi.update_traj_dt(dt=trajopt_dt)
            aux_rollout.update_traj_dt(dt=trajopt_dt)
            arm_rollout_grad.update_traj_dt(dt=trajopt_dt)
            arm_rollout_safety.update_traj_dt(dt=trajopt_dt)

        config_dict = ParallelMPPIConfig.create_data_dict(
            config_data["mppi"], arm_rollout_mppi, tensor_args
        )
        if arm_rollout_mppi is not None:
            config_dict = ParallelMPPIConfig.create_data_dict(
                config_data["mppi"], arm_rollout_mppi, tensor_args
            )
        parallel_mppi = None
        if use_es is not None and use_es:
            mppi_cfg = ParallelESConfig(**config_dict)
            if es_learning_rate is not None:
                mppi_cfg.learning_rate = es_learning_rate
            parallel_mppi = ParallelES(mppi_cfg)
        else:
        elif use_particle_opt:
            mppi_cfg = ParallelMPPIConfig(**config_dict)
            parallel_mppi = ParallelMPPI(mppi_cfg)

@@ -307,7 +323,7 @@ class TrajOptSolverConfig:
        cfg = WrapConfig(
            safety_rollout=arm_rollout_safety,
            optimizers=opt_list,
            compute_metrics=not evaluate_interpolated_trajectory,
            compute_metrics=True,  # not evaluate_interpolated_trajectory,
            use_cuda_graph_metrics=grad_config_data["lbfgs"]["use_cuda_graph"],
            sync_cuda_time=sync_cuda_time,
        )
@@ -337,6 +353,7 @@ class TrajOptSolverConfig:
            trim_steps=trim_steps,
            store_debug_in_result=store_debug_in_result,
            optimize_dt=optimize_dt,
            use_cuda_graph=use_cuda_graph,
        )
        return trajopt_cfg

@@ -360,6 +377,7 @@ class TrajResult(Sequence):
    optimized_dt: Optional[torch.Tensor] = None
    raw_solution: Optional[JointState] = None
    raw_action: Optional[torch.Tensor] = None
    goalset_index: Optional[torch.Tensor] = None

    def __getitem__(self, idx):
        # position_error = rotation_error = cspace_error = path_buffer_last_tstep = None
@@ -372,6 +390,7 @@ class TrajResult(Sequence):
            self.position_error,
            self.rotation_error,
            self.cspace_error,
            self.goalset_index,
        ]
        idx_vals = list_idx_if_not_none(d_list, idx)

@@ -388,6 +407,7 @@ class TrajResult(Sequence):
            position_error=idx_vals[3],
            rotation_error=idx_vals[4],
            cspace_error=idx_vals[5],
            goalset_index=idx_vals[6],
        )

    def __len__(self):
@@ -405,7 +425,7 @@ class TrajOptSolver(TrajOptSolverConfig):
            3, int(self.action_horizon / 2), self.tensor_args
        ).unsqueeze(0)
        assert self.action_horizon / 2 != 0.0
        self.solver.update_nenvs(self.num_seeds)
        self.solver.update_nproblems(self.num_seeds)
        self._max_joint_vel = (
            self.solver.safety_rollout.state_bounds.velocity.view(2, self.dof)[1, :].reshape(
                1, 1, self.dof
@@ -469,12 +489,18 @@ class TrajOptSolver(TrajOptSolverConfig):
            self._goal_buffer,
            self.tensor_args,
        )

        if update_reference:
            self.solver.update_nenvs(self._solve_state.get_batch_size())
            self.reset_cuda_graph()
            if self.use_cuda_graph and self._col is not None:
                log_error("changing goal type, breaking previous cuda graph.")
                self.reset_cuda_graph()

            self.solver.update_nproblems(self._solve_state.get_batch_size())
            self._col = torch.arange(
                0, goal.batch, device=self.tensor_args.device, dtype=torch.long
            )
            self.reset_shape()

        return self._goal_buffer

    def solve_any(
@@ -586,6 +612,8 @@ class TrajOptSolver(TrajOptSolverConfig):
            num_seeds,
            solve_state.batch_mode,
        )
        if traj_result.goalset_index is not None:
            traj_result.goalset_index[traj_result.goalset_index >= goal.goal_pose.n_goalset] = 0
        if newton_iters is not None:
            self.solver.newton_optimizer.outer_iters = self._og_newton_iters
            self.solver.newton_optimizer.fixed_iters = self._og_newton_fixed_iters
@@ -839,16 +867,18 @@ class TrajOptSolver(TrajOptSolverConfig):
        if self.evaluate_interpolated_trajectory:
            with profiler.record_function("trajopt/evaluate_interpolated"):
                if self.use_cuda_graph_metrics:
                    result.metrics = self.interpolate_rollout.get_metrics_cuda_graph(
                        interpolated_trajs
                    )
                    metrics = self.interpolate_rollout.get_metrics_cuda_graph(interpolated_trajs)
                else:
                    result.metrics = self.interpolate_rollout.get_metrics(interpolated_trajs)
                    metrics = self.interpolate_rollout.get_metrics(interpolated_trajs)
                result.metrics.feasible = metrics.feasible
                result.metrics.position_error = metrics.position_error
                result.metrics.rotation_error = metrics.rotation_error
                result.metrics.cspace_error = metrics.cspace_error
                result.metrics.goalset_index = metrics.goalset_index

        st_time = time.time()
        feasible = torch.all(result.metrics.feasible, dim=-1)
        # if self.num_seeds == 1:
        #     print(result.metrics)

        if result.metrics.position_error is not None:
            converge = torch.logical_and(
                result.metrics.position_error[..., -1] <= self.position_threshold,
@@ -877,10 +907,10 @@ class TrajOptSolver(TrajOptSolverConfig):
                optimized_dt=opt_dt,
                raw_solution=result.action,
                raw_action=result.raw_action,
                goalset_index=result.metrics.goalset_index,
            )
        else:
            # get path length:
            # max_vel =
            if self.evaluate_interpolated_trajectory:
                smooth_label, smooth_cost = self.traj_evaluator.evaluate_interpolated_smootheness(
                    interpolated_trajs,
@@ -896,7 +926,6 @@ class TrajOptSolver(TrajOptSolverConfig):
                    self.solver.rollout_fn.dynamics_model.cspace_distance_weight,
                    self._velocity_bounds,
                )
            # print(smooth_label, success, self._velocity_bounds.shape, self.solver.rollout_fn.dynamics_model.cspace_distance_weight)

            with profiler.record_function("trajopt/best_select"):
                success[~smooth_label] = False
@@ -907,7 +936,8 @@ class TrajOptSolver(TrajOptSolverConfig):
                    convergence_error = result.metrics.cspace_error[..., -1]
                else:
                    raise ValueError("convergence check requires either goal_pose or goal_state")
                error = convergence_error + smooth_cost
                running_cost = torch.mean(result.metrics.cost, dim=-1) * 0.0001
                error = convergence_error + smooth_cost + running_cost
                error[~success] += 10000.0
                if batch_mode:
                    idx = torch.argmin(error.view(goal.batch, num_seeds), dim=-1)
@@ -923,13 +953,16 @@ class TrajOptSolver(TrajOptSolverConfig):
                best_act_seq = result.action[idx]
                best_raw_action = result.raw_action[idx]
                interpolated_traj = interpolated_trajs[idx]
                position_error = rotation_error = cspace_error = None
                goalset_index = position_error = rotation_error = cspace_error = None
                if result.metrics.position_error is not None:
                    position_error = result.metrics.position_error[idx, -1]
                if result.metrics.rotation_error is not None:
                    rotation_error = result.metrics.rotation_error[idx, -1]
                if result.metrics.cspace_error is not None:
                    cspace_error = result.metrics.cspace_error[idx, -1]
                if result.metrics.goalset_index is not None:
                    goalset_index = result.metrics.goalset_index[idx, -1]

                opt_dt = opt_dt[idx]
                if self.sync_cuda_time:
                    torch.cuda.synchronize()
@@ -965,6 +998,7 @@ class TrajOptSolver(TrajOptSolverConfig):
                optimized_dt=opt_dt,
                raw_solution=best_act_seq,
                raw_action=best_raw_action,
                goalset_index=goalset_index,
            )
        return traj_result

@@ -999,7 +1033,6 @@ class TrajOptSolver(TrajOptSolverConfig):
            return self.solve_batch_goalset(
                goal, seed_traj, use_nn_seed, return_all_solutions, num_seeds, seed_success
            )
        return traj_result

    def get_linear_seed(self, start_state, goal_state):
        start_q = start_state.position.view(-1, 1, self.dof)
@@ -1173,6 +1206,11 @@ class TrajOptSolver(TrajOptSolverConfig):
        self.interpolate_rollout.reset_cuda_graph()
        self.rollout_fn.reset_cuda_graph()

    def reset_shape(self):
        self.solver.reset_shape()
        self.interpolate_rollout.reset_shape()
        self.rollout_fn.reset_shape()

    @property
    def kinematics(self) -> CudaRobotModel:
        return self.rollout_fn.dynamics_model.robot_model
@@ -1205,3 +1243,14 @@ class TrajOptSolver(TrajOptSolverConfig):

    def get_full_js(self, active_js: JointState) -> JointState:
        return self.rollout_fn.get_full_dof_from_solution(active_js)

    def update_pose_cost_metric(
        self,
        metric: PoseCostMetric,
    ):
        rollouts = self.get_all_rollout_instances()
        [
            rollout.update_pose_cost_metric(metric)
            for rollout in rollouts
            if isinstance(rollout, ArmReacher)
        ]
@@ -13,7 +13,7 @@ from __future__ import annotations
# Standard Library
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

# CuRobo
from curobo.rollout.rollout_base import Goal
@@ -126,16 +126,26 @@ class ReacherSolveState:
        if (
            current_solve_state is None
            or current_goal_buffer is None
            or current_solve_state != solve_state
            or (current_goal_buffer.retract_state is None and retract_config is not None)
            or (current_goal_buffer.goal_state is None and goal_state is not None)
            or (current_goal_buffer.links_goal_pose is None and link_poses is not None)
        ):
            update_reference = True

        elif current_solve_state != solve_state:
            new_goal_pose = get_padded_goalset(
                solve_state, current_solve_state, current_goal_buffer, goal_pose
            )
            if new_goal_pose is not None:
                goal_pose = new_goal_pose
            else:
                update_reference = True

        if update_reference:
            current_solve_state = solve_state
            current_goal_buffer = solve_state.create_goal_buffer(
                goal_pose, goal_state, retract_config, link_poses, tensor_args
            )
            update_reference = True
        else:
            current_goal_buffer.goal_pose.copy_(goal_pose)
            if retract_config is not None:
@@ -145,6 +155,7 @@ class ReacherSolveState:
            if link_poses is not None:
                for k in link_poses.keys():
                    current_goal_buffer.links_goal_pose[k].copy_(link_poses[k])

        return current_solve_state, current_goal_buffer, update_reference

    def update_goal(
@@ -155,17 +166,26 @@ class ReacherSolveState:
        tensor_args: TensorDeviceType = TensorDeviceType(),
    ):
        solve_state = self

        update_reference = False
        if (
            current_solve_state is None
            or current_goal_buffer is None
            or current_solve_state != solve_state
            or (current_goal_buffer.goal_state is None and goal.goal_state is not None)
            or (current_goal_buffer.goal_state is not None and goal.goal_state is None)
        ):
            # TODO: Check for change in update idx buffers, currently we assume
            # that solve_state captures difference in idx buffers
            update_reference = True
        elif current_solve_state != solve_state:
            new_goal_pose = get_padded_goalset(
                solve_state, current_solve_state, current_goal_buffer, goal.goal_pose
            )
            if new_goal_pose is not None:
                goal = goal.clone()
                goal.goal_pose = new_goal_pose

            else:
                update_reference = True

        if update_reference:
            current_solve_state = solve_state
            current_goal_buffer = goal.create_index_buffers(
                solve_state.batch_size,
@@ -174,7 +194,6 @@ class ReacherSolveState:
                solve_state.num_seeds,
                tensor_args,
            )
            update_reference = True
        else:
            current_goal_buffer.copy_(goal, update_idx_buffers=False)
        return current_solve_state, current_goal_buffer, update_reference
@@ -185,3 +204,92 @@ class MotionGenSolverState:
    solve_type: ReacherSolveType
    ik_solve_state: ReacherSolveState
    trajopt_solve_state: ReacherSolveState


def get_padded_goalset(
    solve_state: ReacherSolveState,
    current_solve_state: ReacherSolveState,
    current_goal_buffer: Goal,
    new_goal_pose: Pose,
) -> Union[Pose, None]:
    if (
        current_solve_state.solve_type == ReacherSolveType.GOALSET
        and solve_state.solve_type == ReacherSolveType.SINGLE
    ):
        # convert single goal to goal set
        # solve_state.solve_type = ReacherSolveType.GOALSET
        # solve_state.n_goalset = current_solve_state.n_goalset

        goal_pose = current_goal_buffer.goal_pose.clone()
        goal_pose.position[:] = new_goal_pose.position
        goal_pose.quaternion[:] = new_goal_pose.quaternion
        return goal_pose

    elif (
        current_solve_state.solve_type == ReacherSolveType.BATCH_GOALSET
        and solve_state.solve_type == ReacherSolveType.BATCH
        and new_goal_pose.n_goalset <= current_solve_state.n_goalset
        and new_goal_pose.batch == current_solve_state.batch_size
    ):
        goal_pose = current_goal_buffer.goal_pose.clone()
        if len(new_goal_pose.position.shape) == 2:
            new_goal_pose = new_goal_pose.unsqueeze(1)
        goal_pose.position[..., :, :] = new_goal_pose.position
        goal_pose.quaternion[..., :, :] = new_goal_pose.quaternion
        return goal_pose
    elif (
        current_solve_state.solve_type == ReacherSolveType.BATCH_ENV_GOALSET
        and solve_state.solve_type == ReacherSolveType.BATCH_ENV
        and new_goal_pose.n_goalset <= current_solve_state.n_goalset
        and new_goal_pose.batch == current_solve_state.batch_size
    ):
        goal_pose = current_goal_buffer.goal_pose.clone()
        if len(new_goal_pose.position.shape) == 2:
            new_goal_pose = new_goal_pose.unsqueeze(1)
        goal_pose.position[..., :, :] = new_goal_pose.position
        goal_pose.quaternion[..., :, :] = new_goal_pose.quaternion
        return goal_pose

    elif (
        current_solve_state.solve_type == ReacherSolveType.GOALSET
        and solve_state.solve_type == ReacherSolveType.GOALSET
        and new_goal_pose.n_goalset <= current_solve_state.n_goalset
    ):
        goal_pose = current_goal_buffer.goal_pose.clone()
        goal_pose.position[..., : new_goal_pose.n_goalset, :] = new_goal_pose.position
        goal_pose.quaternion[..., : new_goal_pose.n_goalset, :] = new_goal_pose.quaternion
        goal_pose.position[..., new_goal_pose.n_goalset :, :] = new_goal_pose.position[..., :1, :]
        goal_pose.quaternion[..., new_goal_pose.n_goalset :, :] = new_goal_pose.quaternion[
            ..., :1, :
        ]

        return goal_pose
    elif (
        current_solve_state.solve_type == ReacherSolveType.BATCH_GOALSET
        and solve_state.solve_type == ReacherSolveType.BATCH_GOALSET
        and new_goal_pose.n_goalset <= current_solve_state.n_goalset
        and new_goal_pose.batch == current_solve_state.batch_size
    ):
        goal_pose = current_goal_buffer.goal_pose.clone()
        goal_pose.position[..., : new_goal_pose.n_goalset, :] = new_goal_pose.position
        goal_pose.quaternion[..., : new_goal_pose.n_goalset, :] = new_goal_pose.quaternion
        goal_pose.position[..., new_goal_pose.n_goalset :, :] = new_goal_pose.position[..., :1, :]
        goal_pose.quaternion[..., new_goal_pose.n_goalset :, :] = new_goal_pose.quaternion[
            ..., :1, :
        ]
        return goal_pose
    elif (
        current_solve_state.solve_type == ReacherSolveType.BATCH_ENV_GOALSET
        and solve_state.solve_type == ReacherSolveType.BATCH_ENV_GOALSET
        and new_goal_pose.n_goalset <= current_solve_state.n_goalset
        and new_goal_pose.batch == current_solve_state.batch_size
    ):
        goal_pose = current_goal_buffer.goal_pose.clone()
        goal_pose.position[..., : new_goal_pose.n_goalset, :] = new_goal_pose.position
        goal_pose.quaternion[..., : new_goal_pose.n_goalset, :] = new_goal_pose.quaternion
        goal_pose.position[..., new_goal_pose.n_goalset :, :] = new_goal_pose.position[..., :1, :]
        goal_pose.quaternion[..., new_goal_pose.n_goalset :, :] = new_goal_pose.quaternion[
            ..., :1, :
        ]
        return goal_pose
    return None
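
get_padded_goalset lets a smaller (or single) goal query reuse a goal buffer that was allocated for a larger goal set, so the CUDA graph does not need to be rebuilt: the new goals fill the first n_goalset slots and the first goal is repeated into the remaining slots. A shape-level sketch of that padding branch, assuming a buffer built for 4 goals receiving 2 new goals (values are illustrative):

import torch

# goal buffer shaped for a goal set of 4: [batch, n_goalset, 3]
buffer_position = torch.zeros((1, 4, 3))
new_position = torch.tensor([[[0.4, 0.0, 0.5], [0.4, 0.2, 0.5]]])  # 2 new goals
n_new = new_position.shape[1]

buffer_position[..., :n_new, :] = new_position               # slots 0-1: the new goals
buffer_position[..., n_new:, :] = new_position[..., :1, :]   # slots 2-3: repeat the first goal
print(buffer_position[0])
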
@@ -54,7 +54,7 @@ class WrapBase(WrapConfig):
    def __init__(self, config: Optional[WrapConfig] = None):
        if config is not None:
            WrapConfig.__init__(self, **vars(config))
        self.n_envs = 1
        self.n_problems = 1
        self.opt_dt = 0
        self._rollout_list = None
        self._opt_rollouts = None
@@ -83,11 +83,11 @@ class WrapBase(WrapConfig):
            debug_list.append(opt.debug_cost)
        return debug_list

    def update_nenvs(self, n_envs):
        if n_envs != self.n_envs:
            self.n_envs = n_envs
    def update_nproblems(self, n_problems):
        if n_problems != self.n_problems:
            self.n_problems = n_problems
            for opt in self.optimizers:
                opt.update_nenvs(self.n_envs)
                opt.update_nproblems(self.n_problems)

    def update_params(self, goal: Goal):
        with profiler.record_function("wrap_base/safety/update_params"):
@@ -117,6 +117,12 @@ class WrapBase(WrapConfig):
            opt.reset_cuda_graph()
        self._init_solver = False

    def reset_shape(self):
        self.safety_rollout.reset_shape()
        for opt in self.optimizers:
            opt.reset_shape()
        self._init_solver = False

    @property
    def rollout_fn(self):
        return self.safety_rollout