constrained planning, robot segmentation

This commit is contained in:
Balakumar Sundaralingam
2024-02-22 21:45:47 -08:00
parent 88eac64edc
commit bafdf80c05
102 changed files with 12440 additions and 8112 deletions

View File

@@ -153,6 +153,8 @@ def get_pose_distance(
vec_convergence,
run_weight,
run_vec_weight,
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
batch_size,
horizon,
@@ -161,6 +163,7 @@ def get_pose_distance(
write_grad=False,
write_distance=False,
use_metric=False,
project_distance=True,
):
if batch_pose_idx.shape[0] != batch_size:
raise ValueError("Index buffer size is different from batch size")
@@ -181,6 +184,8 @@ def get_pose_distance(
vec_convergence,
run_weight,
run_vec_weight,
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
batch_size,
horizon,
@@ -189,6 +194,7 @@ def get_pose_distance(
write_grad,
write_distance,
use_metric,
project_distance,
)
out_distance = r[0]
@@ -229,6 +235,331 @@ def get_pose_distance_backward(
return r[0], r[1]
@torch.jit.script
def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
grad_vec = grad_g_dist + (grad_out_distance * weight)
grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec
return grad
# full method:
@torch.jit.script
def backward_full_PoseError_jit(
grad_out_distance, grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
):
p_grad = (grad_g_dist + (grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
q_grad = (grad_r_err + (grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
# p_grad = ((grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
# q_grad = ((grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
return p_grad, q_grad
class PoseErrorDistance(torch.autograd.Function):
@staticmethod
def forward(
ctx,
current_position,
goal_position,
current_quat,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode, # =PoseErrorType.BATCH_GOAL.value,
num_goals,
use_metric, # =False,
project_distance, # =True,
):
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)
(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
True,
use_metric,
project_distance,
)
ctx.save_for_backward(out_p_vec, out_r_vec, weight, out_p_grad, out_q_grad)
return out_distance, out_position_distance, out_rotation_distance, out_idx # .view(-1,1)
@staticmethod
def backward(ctx, grad_out_distance, grad_g_dist, grad_r_err, grad_out_idx):
(g_vec_p, g_vec_q, weight, out_grad_p, out_grad_q) = ctx.saved_tensors
pos_grad = None
quat_grad = None
batch_size = g_vec_p.shape[0] * g_vec_p.shape[1]
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
pos_grad, quat_grad = get_pose_distance_backward(
out_grad_p,
out_grad_q,
grad_out_distance.contiguous(),
grad_g_dist.contiguous(),
grad_r_err.contiguous(),
weight,
g_vec_p,
g_vec_q,
batch_size,
use_distance=True,
)
elif ctx.needs_input_grad[0]:
pos_grad = backward_PoseError_jit(grad_g_dist, grad_out_distance, weight[1], g_vec_p)
elif ctx.needs_input_grad[2]:
quat_grad = backward_PoseError_jit(grad_r_err, grad_out_distance, weight[0], g_vec_q)
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class PoseError(torch.autograd.Function):
@staticmethod
def forward(
ctx,
current_position: torch.Tensor,
goal_position: torch.Tensor,
current_quat: torch.Tensor,
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
out_p_grad,
out_q_grad,
batch_size,
horizon,
mode,
num_goals,
use_metric,
project_distance,
return_loss,
):
"""Compute error in pose
_extended_summary_
Args:
ctx: _description_
current_position: _description_
goal_position: _description_
current_quat: _description_
goal_quat: _description_
vec_weight: _description_
weight: _description_
vec_convergence: _description_
run_weight: _description_
run_vec_weight: _description_
offset_waypoint: _description_
offset_tstep_fraction: _description_
batch_pose_idx: _description_
out_distance: _description_
out_position_distance: _description_
out_rotation_distance: _description_
out_p_vec: _description_
out_r_vec: _description_
out_idx: _description_
out_p_grad: _description_
out_q_grad: _description_
batch_size: _description_
horizon: _description_
mode: _description_
num_goals: _description_
use_metric: _description_
project_distance: _description_
return_loss: _description_
Returns:
_description_
"""
# out_distance = current_position[..., 0].detach().clone() * 0.0
# out_position_distance = out_distance.detach().clone()
# out_rotation_distance = out_distance.detach().clone()
# out_vec = (
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
# * 0.0
# )
# out_idx = out_distance.clone().to(dtype=torch.long)
ctx.return_loss = return_loss
(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
) = get_pose_distance(
out_distance,
out_position_distance,
out_rotation_distance,
out_p_vec,
out_r_vec,
out_idx,
current_position.contiguous(),
goal_position,
current_quat.contiguous(),
goal_quat,
vec_weight,
weight,
vec_convergence,
run_weight,
run_vec_weight,
offset_waypoint,
offset_tstep_fraction,
batch_pose_idx,
batch_size,
horizon,
mode,
num_goals,
current_position.requires_grad,
False,
use_metric,
project_distance,
)
ctx.save_for_backward(out_p_vec, out_r_vec)
return out_distance
@staticmethod
def backward(ctx, grad_out_distance): # , grad_g_dist, grad_r_err, grad_out_idx):
pos_grad = None
quat_grad = None
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p
quat_grad = g_vec_q
if ctx.return_loss:
pos_grad = pos_grad * grad_out_distance.unsqueeze(1)
quat_grad = quat_grad * grad_out_distance.unsqueeze(1)
elif ctx.needs_input_grad[0]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
pos_grad = g_vec_p
if ctx.return_loss:
pos_grad = pos_grad * grad_out_distance.unsqueeze(1)
elif ctx.needs_input_grad[2]:
(g_vec_p, g_vec_q) = ctx.saved_tensors
quat_grad = g_vec_q
if ctx.return_loss:
quat_grad = quat_grad * grad_out_distance.unsqueeze(1)
return (
pos_grad,
None,
quat_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class SdfSphereOBB(torch.autograd.Function):
@staticmethod
def forward(