update to 0.6.2
This commit is contained in:
@@ -64,7 +64,9 @@ def get_stomp_cov(
|
||||
Coefficients from here: https://en.wikipedia.org/wiki/Finite_difference_coefficient
|
||||
More info here: https://github.com/ros-industrial/stomp_ros/blob/7fe40fbe6ad446459d8d4889916c64e276dbf882/stomp_core/src/utils.cpp#L36
|
||||
"""
|
||||
cov, scale_tril, scaled_M = get_stomp_cov_jit(horizon, d_action, cov_mode)
|
||||
cov, scale_tril, scaled_M = get_stomp_cov_jit(
|
||||
horizon, d_action, cov_mode, device=tensor_args.device
|
||||
)
|
||||
cov = tensor_args.to_device(cov)
|
||||
scale_tril = tensor_args.to_device(scale_tril)
|
||||
if RETURN_M:
|
||||
@@ -77,13 +79,16 @@ def get_stomp_cov_jit(
|
||||
horizon: int,
|
||||
d_action: int,
|
||||
cov_mode: str = "acc",
|
||||
device: torch.device = torch.device("cuda:0"),
|
||||
):
|
||||
# This function can lead to nans. There are checks to raise an error when nan occurs.
|
||||
vel_fd_array = [0.0, 0.0, 1.0, -2.0, 1.0, 0.0, 0.0]
|
||||
|
||||
fd_array = vel_fd_array
|
||||
A = torch.zeros(
|
||||
(d_action * horizon, d_action * horizon),
|
||||
dtype=torch.float64,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if cov_mode == "vel":
|
||||
@@ -117,14 +122,17 @@ def get_stomp_cov_jit(
|
||||
A[k * horizon + i, k * horizon + index] = fd_array[j + 3]
|
||||
|
||||
R = torch.matmul(A.transpose(-2, -1), A)
|
||||
|
||||
M = torch.inverse(R)
|
||||
scaled_M = (1 / horizon) * M / (torch.max(torch.abs(M), dim=1)[0].unsqueeze(0))
|
||||
cov = M / torch.max(torch.abs(M))
|
||||
|
||||
# also compute the cholesky decomposition:
|
||||
# scale_tril = torch.zeros((d_action * horizon, d_action * horizon), **tensor_args)
|
||||
scale_tril = torch.linalg.cholesky(cov)
|
||||
if (cov == cov.T).all() and (torch.linalg.eigvals(cov).real >= 0).all():
|
||||
scale_tril = torch.linalg.cholesky(cov)
|
||||
else:
|
||||
scale_tril = cov
|
||||
|
||||
"""
|
||||
k = 0
|
||||
act_cov_matrix = cov[k * horizon:k * horizon + horizon, k * horizon:k * horizon + horizon]
|
||||
|
||||
Reference in New Issue
Block a user