constrained planning, robot segmentation
This commit is contained in:
@@ -153,6 +153,8 @@ def get_pose_distance(
|
||||
vec_convergence,
|
||||
run_weight,
|
||||
run_vec_weight,
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
batch_pose_idx,
|
||||
batch_size,
|
||||
horizon,
|
||||
@@ -161,6 +163,7 @@ def get_pose_distance(
|
||||
write_grad=False,
|
||||
write_distance=False,
|
||||
use_metric=False,
|
||||
project_distance=True,
|
||||
):
|
||||
if batch_pose_idx.shape[0] != batch_size:
|
||||
raise ValueError("Index buffer size is different from batch size")
|
||||
@@ -181,6 +184,8 @@ def get_pose_distance(
|
||||
vec_convergence,
|
||||
run_weight,
|
||||
run_vec_weight,
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
batch_pose_idx,
|
||||
batch_size,
|
||||
horizon,
|
||||
@@ -189,6 +194,7 @@ def get_pose_distance(
|
||||
write_grad,
|
||||
write_distance,
|
||||
use_metric,
|
||||
project_distance,
|
||||
)
|
||||
|
||||
out_distance = r[0]
|
||||
@@ -229,6 +235,331 @@ def get_pose_distance_backward(
|
||||
return r[0], r[1]
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def backward_PoseError_jit(grad_g_dist, grad_out_distance, weight, g_vec):
|
||||
grad_vec = grad_g_dist + (grad_out_distance * weight)
|
||||
grad = 1.0 * (grad_vec).unsqueeze(-1) * g_vec
|
||||
return grad
|
||||
|
||||
|
||||
# full method:
|
||||
@torch.jit.script
|
||||
def backward_full_PoseError_jit(
|
||||
grad_out_distance, grad_g_dist, grad_r_err, p_w, q_w, g_vec_p, g_vec_q
|
||||
):
|
||||
p_grad = (grad_g_dist + (grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
|
||||
q_grad = (grad_r_err + (grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
|
||||
# p_grad = ((grad_out_distance * p_w)).unsqueeze(-1) * g_vec_p
|
||||
# q_grad = ((grad_out_distance * q_w)).unsqueeze(-1) * g_vec_q
|
||||
|
||||
return p_grad, q_grad
|
||||
|
||||
|
||||
class PoseErrorDistance(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
current_position,
|
||||
goal_position,
|
||||
current_quat,
|
||||
goal_quat,
|
||||
vec_weight,
|
||||
weight,
|
||||
vec_convergence,
|
||||
run_weight,
|
||||
run_vec_weight,
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
batch_pose_idx,
|
||||
out_distance,
|
||||
out_position_distance,
|
||||
out_rotation_distance,
|
||||
out_p_vec,
|
||||
out_r_vec,
|
||||
out_idx,
|
||||
out_p_grad,
|
||||
out_q_grad,
|
||||
batch_size,
|
||||
horizon,
|
||||
mode, # =PoseErrorType.BATCH_GOAL.value,
|
||||
num_goals,
|
||||
use_metric, # =False,
|
||||
project_distance, # =True,
|
||||
):
|
||||
# out_distance = current_position[..., 0].detach().clone() * 0.0
|
||||
# out_position_distance = out_distance.detach().clone()
|
||||
# out_rotation_distance = out_distance.detach().clone()
|
||||
# out_vec = (
|
||||
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
|
||||
# * 0.0
|
||||
# )
|
||||
# out_idx = out_distance.clone().to(dtype=torch.long)
|
||||
|
||||
(
|
||||
out_distance,
|
||||
out_position_distance,
|
||||
out_rotation_distance,
|
||||
out_p_vec,
|
||||
out_r_vec,
|
||||
out_idx,
|
||||
) = get_pose_distance(
|
||||
out_distance,
|
||||
out_position_distance,
|
||||
out_rotation_distance,
|
||||
out_p_vec,
|
||||
out_r_vec,
|
||||
out_idx,
|
||||
current_position.contiguous(),
|
||||
goal_position,
|
||||
current_quat.contiguous(),
|
||||
goal_quat,
|
||||
vec_weight,
|
||||
weight,
|
||||
vec_convergence,
|
||||
run_weight,
|
||||
run_vec_weight,
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
batch_pose_idx,
|
||||
batch_size,
|
||||
horizon,
|
||||
mode,
|
||||
num_goals,
|
||||
current_position.requires_grad,
|
||||
True,
|
||||
use_metric,
|
||||
project_distance,
|
||||
)
|
||||
ctx.save_for_backward(out_p_vec, out_r_vec, weight, out_p_grad, out_q_grad)
|
||||
return out_distance, out_position_distance, out_rotation_distance, out_idx # .view(-1,1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out_distance, grad_g_dist, grad_r_err, grad_out_idx):
|
||||
(g_vec_p, g_vec_q, weight, out_grad_p, out_grad_q) = ctx.saved_tensors
|
||||
pos_grad = None
|
||||
quat_grad = None
|
||||
batch_size = g_vec_p.shape[0] * g_vec_p.shape[1]
|
||||
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
|
||||
pos_grad, quat_grad = get_pose_distance_backward(
|
||||
out_grad_p,
|
||||
out_grad_q,
|
||||
grad_out_distance.contiguous(),
|
||||
grad_g_dist.contiguous(),
|
||||
grad_r_err.contiguous(),
|
||||
weight,
|
||||
g_vec_p,
|
||||
g_vec_q,
|
||||
batch_size,
|
||||
use_distance=True,
|
||||
)
|
||||
|
||||
elif ctx.needs_input_grad[0]:
|
||||
pos_grad = backward_PoseError_jit(grad_g_dist, grad_out_distance, weight[1], g_vec_p)
|
||||
|
||||
elif ctx.needs_input_grad[2]:
|
||||
quat_grad = backward_PoseError_jit(grad_r_err, grad_out_distance, weight[0], g_vec_q)
|
||||
|
||||
return (
|
||||
pos_grad,
|
||||
None,
|
||||
quat_grad,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class PoseError(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
current_position: torch.Tensor,
|
||||
goal_position: torch.Tensor,
|
||||
current_quat: torch.Tensor,
|
||||
goal_quat,
|
||||
vec_weight,
|
||||
weight,
|
||||
vec_convergence,
|
||||
run_weight,
|
||||
run_vec_weight,
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
batch_pose_idx,
|
||||
out_distance,
|
||||
out_position_distance,
|
||||
out_rotation_distance,
|
||||
out_p_vec,
|
||||
out_r_vec,
|
||||
out_idx,
|
||||
out_p_grad,
|
||||
out_q_grad,
|
||||
batch_size,
|
||||
horizon,
|
||||
mode,
|
||||
num_goals,
|
||||
use_metric,
|
||||
project_distance,
|
||||
return_loss,
|
||||
):
|
||||
"""Compute error in pose
|
||||
|
||||
_extended_summary_
|
||||
|
||||
Args:
|
||||
ctx: _description_
|
||||
current_position: _description_
|
||||
goal_position: _description_
|
||||
current_quat: _description_
|
||||
goal_quat: _description_
|
||||
vec_weight: _description_
|
||||
weight: _description_
|
||||
vec_convergence: _description_
|
||||
run_weight: _description_
|
||||
run_vec_weight: _description_
|
||||
offset_waypoint: _description_
|
||||
offset_tstep_fraction: _description_
|
||||
batch_pose_idx: _description_
|
||||
out_distance: _description_
|
||||
out_position_distance: _description_
|
||||
out_rotation_distance: _description_
|
||||
out_p_vec: _description_
|
||||
out_r_vec: _description_
|
||||
out_idx: _description_
|
||||
out_p_grad: _description_
|
||||
out_q_grad: _description_
|
||||
batch_size: _description_
|
||||
horizon: _description_
|
||||
mode: _description_
|
||||
num_goals: _description_
|
||||
use_metric: _description_
|
||||
project_distance: _description_
|
||||
return_loss: _description_
|
||||
|
||||
Returns:
|
||||
_description_
|
||||
"""
|
||||
# out_distance = current_position[..., 0].detach().clone() * 0.0
|
||||
# out_position_distance = out_distance.detach().clone()
|
||||
# out_rotation_distance = out_distance.detach().clone()
|
||||
# out_vec = (
|
||||
# torch.cat((current_position.detach().clone(), current_quat.detach().clone()), dim=-1)
|
||||
# * 0.0
|
||||
# )
|
||||
# out_idx = out_distance.clone().to(dtype=torch.long)
|
||||
ctx.return_loss = return_loss
|
||||
(
|
||||
out_distance,
|
||||
out_position_distance,
|
||||
out_rotation_distance,
|
||||
out_p_vec,
|
||||
out_r_vec,
|
||||
out_idx,
|
||||
) = get_pose_distance(
|
||||
out_distance,
|
||||
out_position_distance,
|
||||
out_rotation_distance,
|
||||
out_p_vec,
|
||||
out_r_vec,
|
||||
out_idx,
|
||||
current_position.contiguous(),
|
||||
goal_position,
|
||||
current_quat.contiguous(),
|
||||
goal_quat,
|
||||
vec_weight,
|
||||
weight,
|
||||
vec_convergence,
|
||||
run_weight,
|
||||
run_vec_weight,
|
||||
offset_waypoint,
|
||||
offset_tstep_fraction,
|
||||
batch_pose_idx,
|
||||
batch_size,
|
||||
horizon,
|
||||
mode,
|
||||
num_goals,
|
||||
current_position.requires_grad,
|
||||
False,
|
||||
use_metric,
|
||||
project_distance,
|
||||
)
|
||||
ctx.save_for_backward(out_p_vec, out_r_vec)
|
||||
return out_distance
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out_distance): # , grad_g_dist, grad_r_err, grad_out_idx):
|
||||
pos_grad = None
|
||||
quat_grad = None
|
||||
if ctx.needs_input_grad[0] and ctx.needs_input_grad[2]:
|
||||
(g_vec_p, g_vec_q) = ctx.saved_tensors
|
||||
pos_grad = g_vec_p
|
||||
quat_grad = g_vec_q
|
||||
if ctx.return_loss:
|
||||
pos_grad = pos_grad * grad_out_distance.unsqueeze(1)
|
||||
quat_grad = quat_grad * grad_out_distance.unsqueeze(1)
|
||||
|
||||
elif ctx.needs_input_grad[0]:
|
||||
(g_vec_p, g_vec_q) = ctx.saved_tensors
|
||||
|
||||
pos_grad = g_vec_p
|
||||
if ctx.return_loss:
|
||||
pos_grad = pos_grad * grad_out_distance.unsqueeze(1)
|
||||
elif ctx.needs_input_grad[2]:
|
||||
(g_vec_p, g_vec_q) = ctx.saved_tensors
|
||||
|
||||
quat_grad = g_vec_q
|
||||
if ctx.return_loss:
|
||||
quat_grad = quat_grad * grad_out_distance.unsqueeze(1)
|
||||
return (
|
||||
pos_grad,
|
||||
None,
|
||||
quat_grad,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class SdfSphereOBB(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user