Add support for older warp versions (<1.0.0)
This commit is contained in:
@@ -85,10 +85,13 @@ class CuroboTorch(torch.nn.Module):
|
||||
kin_state.ee_quaternion,
|
||||
]
|
||||
if x_des is not None:
|
||||
pose_distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose)
|
||||
pose_distance = self._robot_world.pose_distance(
|
||||
x_des, kin_state.ee_pose, resize=True
|
||||
).view(-1, 1)
|
||||
features.append(pose_distance)
|
||||
features.append(x_des.position)
|
||||
features.append(x_des.quaternion)
|
||||
|
||||
features = torch.cat(features, dim=-1)
|
||||
|
||||
return features
|
||||
@@ -114,17 +117,25 @@ class CuroboTorch(torch.nn.Module):
|
||||
|
||||
def loss(self, x_des: Pose, q: torch.Tensor, q_in: torch.Tensor):
|
||||
kin_state = self._robot_world.get_kinematics(q)
|
||||
distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose)
|
||||
d_sdf = self._robot_world.collision_constraint(kin_state.link_spheres_tensor.unsqueeze(1))
|
||||
d_self = self._robot_world.self_collision_cost(kin_state.link_spheres_tensor.unsqueeze(1))
|
||||
distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose, resize=True)
|
||||
d_sdf = self._robot_world.collision_constraint(
|
||||
kin_state.link_spheres_tensor.unsqueeze(1)
|
||||
).view(-1)
|
||||
d_self = self._robot_world.self_collision_cost(
|
||||
kin_state.link_spheres_tensor.unsqueeze(1)
|
||||
).view(-1)
|
||||
loss = 0.1 * torch.linalg.norm(q_in - q, dim=-1) + distance + 100.0 * (d_self + d_sdf)
|
||||
return loss
|
||||
|
||||
def val_loss(self, x_des: Pose, q: torch.Tensor, q_in: torch.Tensor):
|
||||
kin_state = self._robot_world.get_kinematics(q)
|
||||
distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose)
|
||||
d_sdf = self._robot_world.collision_constraint(kin_state.link_spheres_tensor.unsqueeze(1))
|
||||
d_self = self._robot_world.self_collision_cost(kin_state.link_spheres_tensor.unsqueeze(1))
|
||||
distance = self._robot_world.pose_distance(x_des, kin_state.ee_pose, resize=True)
|
||||
d_sdf = self._robot_world.collision_constraint(
|
||||
kin_state.link_spheres_tensor.unsqueeze(1)
|
||||
).view(-1)
|
||||
d_self = self._robot_world.self_collision_cost(
|
||||
kin_state.link_spheres_tensor.unsqueeze(1)
|
||||
).view(-1)
|
||||
loss = 10.0 * (d_self + d_sdf) + distance
|
||||
return loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user