release repository

This commit is contained in:
Balakumar Sundaralingam
2023-10-26 04:17:19 -07:00
commit 07e6ccfc91
287 changed files with 70659 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#

View File

@@ -0,0 +1,14 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
"""
This module contains code for cuda accelerated kinematics.
"""

View File

@@ -0,0 +1,215 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from typing import Optional
# Third Party
import torch
import torch.autograd.profiler as profiler
# CuRobo
from curobo.curobolib.opt import LBFGScu
from curobo.opt.newton.newton_base import NewtonOptBase, NewtonOptConfig
from curobo.util.logger import log_warn
# kernel for l-bfgs:
@torch.jit.script
def compute_step_direction(
    alpha_buffer,
    rho_buffer,
    y_buffer,
    s_buffer,
    grad_q,
    m: int,
    epsilon: float,
    stable_mode: bool = True,
):
    """Compute the L-BFGS step direction with the two-loop recursion.

    Shapes below follow the buffers allocated by ``LBFGSOpt.init_hessian``
    and the column layout produced in ``LBFGSOpt._get_step_direction``.

    Args:
        alpha_buffer: scratch tensor (m, b, 1, 1); overwritten in place.
        rho_buffer: (m, b, 1, 1) stored ``1 / (y^T s)`` values.
        y_buffer: (m, b, d_opt, 1) stored gradient differences.
        s_buffer: (m, b, d_opt, 1) stored iterate differences.
        grad_q: (b, d_opt, 1) current gradient.
        m: number of history entries to iterate over.
        epsilon: replacement for NaN/+inf/-inf in the initial scaling when
            ``stable_mode`` is set. Annotated as ``float`` because TorchScript
            treats unannotated arguments as Tensors, which would reject the
            Python float passed by the caller.
        stable_mode: when True, sanitize the initial scaling with ``epsilon``.

    Returns:
        The negated two-loop result, same shape as ``grad_q``.
    """
    gq = grad_q.detach().clone()
    rho_s = rho_buffer * s_buffer.transpose(-1, -2)  # (m, b, 1, d_opt)
    # First loop: walk the history from index m-1 down to 0.
    for i in range(m - 1, -1, -1):
        alpha_buffer[i] = rho_s[i] @ gq  # batched dot product
        gq = gq - (alpha_buffer[i] * y_buffer[i])

    # Initial Hessian scaling from the last history slot: (s^T y) / (y^T y).
    newest_y = y_buffer[-1]
    var1 = (s_buffer[-1].transpose(-1, -2) @ newest_y) / (
        newest_y.transpose(-1, -2) @ newest_y
    )
    if stable_mode:
        # An all-zero history gives 0/0 here; fall back to epsilon.
        var1 = torch.nan_to_num(var1, epsilon, epsilon, epsilon)
    gamma = torch.nn.functional.relu(var1)  # negative scalings are clamped to 0
    r = gamma * gq.clone()

    rho_y = rho_buffer * y_buffer.transpose(-1, -2)  # (m, b, 1, d_opt)
    # Second loop: walk the history from index 0 up to m-1.
    for i in range(m):
        r = r + (alpha_buffer[i] - (rho_y[i] @ r)) * s_buffer[i]
    return -1.0 * r
@dataclass
class LBFGSOptConfig(NewtonOptConfig):
    """Configuration for the L-BFGS optimizer on top of ``NewtonOptConfig``."""

    # Number of (s, y) pairs kept in the history buffers.
    history: int = 10
    # Fallback value used to sanitize the initial scaling in stable mode.
    epsilon: float = 0.01
    # Use the fused CUDA kernel for the step direction when True.
    use_cuda_kernel: bool = True
    # Replace non-finite intermediate values instead of propagating them.
    stable_mode: bool = True

    def __post_init__(self):
        """Run the parent configuration post-processing unchanged."""
        parent_result = super().__post_init__()
        return parent_result
class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
@profiler.record_function("lbfgs_opt/init")
def __init__(self, config: Optional[LBFGSOptConfig] = None):
    """Build the optimizer from ``config`` and validate kernel/history limits.

    Args:
        config: optional configuration; when given, its fields are copied onto
            this instance before the base optimizer is initialized.
    """
    if config is not None:
        LBFGSOptConfig.__init__(self, **vars(config))
    NewtonOptBase.__init__(self)
    if self.d_opt >= 1024 or self.history >= 512:
        # Message now states the same bounds the condition checks.
        log_warn("LBFGS: Not using LBFGS Cuda Kernel as d_opt>=1024 or history>=512")
        self.use_cuda_kernel = False
    # Keep history strictly below d_opt, as the warning announces, but never
    # let it reach zero (a zero-length history cannot be indexed with [-1]).
    max_history = max(self.d_opt - 1, 1)
    if self.history > max_history:
        log_warn("LBFGS: history >= d_opt, reducing history to d_opt-1")
        self.history = max_history
        # The history buffers were allocated during NewtonOptBase.__init__
        # (via update_nenvs) with the old history length; reallocate them so
        # their leading dimension matches the loops that iterate self.history.
        self.init_hessian(b=self.n_envs)
@profiler.record_function("lbfgs/reset")
def reset(self):
    """Zero every L-BFGS state tensor in place, then run the parent reset."""
    state_tensors = (
        self.x_0,
        self.grad_0,
        self.s_buffer,
        self.y_buffer,
        self.rho_buffer,
        self.alpha_buffer,
        self.step_q_buffer,
    )
    for tensor in state_tensors:
        # In-place write keeps the existing storage (no reallocation).
        tensor[:] = 0.0
    return super().reset()
def update_nenvs(self, n_envs):
    """Reallocate the L-BFGS buffers for ``n_envs`` and update the parent."""
    # Buffers must exist before the parent hook runs.
    self.init_hessian(b=n_envs)
    parent_result = super().update_nenvs(n_envs)
    return parent_result
def init_hessian(self, b=1):
    """Allocate zero-filled L-BFGS state for a batch of size ``b``.

    Creates ``x_0`` and ``grad_0`` with shape (b, d_opt, 1), ``y_buffer`` and
    ``s_buffer`` with shape (history, b, d_opt, 1), ``rho_buffer`` and
    ``alpha_buffer`` with shape (history, b, 1, 1), and ``step_q_buffer`` with
    shape (b, d_opt), all on ``tensor_args.device`` with ``tensor_args.dtype``.
    """
    dim, hist = self.d_opt, self.history
    opts = {"device": self.tensor_args.device, "dtype": self.tensor_args.dtype}

    # Previous iterate and previous gradient.
    self.x_0 = torch.zeros((b, dim, 1), **opts)
    self.grad_0 = torch.zeros((b, dim, 1), **opts)
    # History of gradient differences (y) and iterate differences (s).
    self.y_buffer = torch.zeros((hist, b, dim, 1), **opts)
    self.s_buffer = torch.zeros((hist, b, dim, 1), **opts)
    # One scalar per history entry and batch element.
    self.rho_buffer = torch.zeros((hist, b, 1, 1), **opts)
    self.step_q_buffer = torch.zeros((b, dim), **opts)
    self.alpha_buffer = torch.zeros((hist, b, 1, 1), **opts)
@torch.no_grad()
def _get_step_direction(self, cost, q, grad_q):
    """Return the L-BFGS step direction reshaped to (-1, d_opt).

    ``cost`` is accepted for interface compatibility and is not used here.
    With ``use_cuda_kernel`` the fused kernel is called with the raw ``q`` and
    ``grad_q``; otherwise the history is updated and the TorchScript
    two-loop recursion is evaluated.
    """
    if not self.use_cuda_kernel:
        # PyTorch path works on column vectors: (b, d_opt, 1).
        grad_col = grad_q.transpose(-1, -2)
        q_col = q.unsqueeze(-1)
        self._update_buffers(q_col, grad_col)
        step = compute_step_direction(
            self.alpha_buffer,
            self.rho_buffer,
            self.y_buffer,
            self.s_buffer,
            grad_col,
            self.history,
            self.epsilon,
            self.stable_mode,
        )
    else:
        with profiler.record_function("lbfgs/fused"):
            step = LBFGScu._call_cuda(
                self.step_q_buffer,
                self.rho_buffer,
                self.y_buffer,
                self.s_buffer,
                q,
                grad_q,
                self.x_0,
                self.grad_0,
                self.epsilon,
                self.stable_mode,
            )
    return step.view(-1, self.d_opt)
def _update_q(self, grad_q):
q = grad_q.detach().clone()
rho_s = self.rho_buffer * self.s_buffer.transpose(-1, -2)
for i in range(self.history - 1, -1, -1):
self.alpha_buffer[i] = rho_s[i] @ q
q = q - (self.alpha_buffer[i] * self.y_buffer[i])
return q
def _update_r(self, r_init):
r = r_init.clone()
rho_y = self.rho_buffer * self.y_buffer.transpose(-1, -2)
for i in range(self.history):
r = r + self.s_buffer[i] * (self.alpha_buffer[i] - rho_y[i] @ r)
return -1.0 * r
def _update_buffers(self, q, grad_q):
y = grad_q - self.grad_0
s = q - self.x_0
rho = 1 / (y.transpose(-1, -2) @ s)
if self.stable_mode:
rho = torch.nan_to_num(rho, 0, 0, 0)
self.s_buffer[0] = s
self.s_buffer[:] = torch.roll(self.s_buffer, -1, dims=0)
self.y_buffer[0] = y
self.y_buffer[:] = torch.roll(self.y_buffer, -1, dims=0) # .copy_(y_buff)
self.rho_buffer[0] = rho
self.rho_buffer[:] = torch.roll(self.rho_buffer, -1, dims=0)
self.x_0.copy_(q)
self.grad_0.copy_(grad_q)
def _shift(self, shift_steps=0):
    """Shift the optimizer state by ``shift_steps * d_action`` entries.

    Does nothing when ``shift_steps`` is 0. Otherwise the state is reset
    first, then ``x_0``, ``grad_0``, ``y_buffer`` and ``s_buffer`` are shifted
    along the optimization-variable axis, and the parent shift is invoked.

    Args:
        shift_steps (int, optional): number of time steps to shift by.
            Defaults to 0.
    """
    if shift_steps == 0:
        return
    self.reset()
    shift_d = shift_steps * self.d_action
    for attr in ("x_0", "grad_0", "y_buffer", "s_buffer"):
        shifted = self._shift_buffer(getattr(self, attr), shift_d, shift_steps)
        setattr(self, attr, shifted)
    super()._shift(shift_steps=shift_steps)
def _shift_buffer(self, buffer, shift_d, shift_steps: int = 1):
buffer = buffer.roll(-shift_d, -2)
end_value = -(shift_steps - 1) * self.d_action
if end_value == 0:
end_value = buffer.shape[-2]
buffer[..., -shift_d:end_value, :] = buffer[
..., -shift_d - self.d_action : -shift_d, :
].clone()
return buffer

View File

@@ -0,0 +1,604 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
import math
import time
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
# Third Party
import torch
import torch.autograd.profiler as profiler
# CuRobo
from curobo.curobolib.ls import update_best, wolfe_line_search
from curobo.opt.opt_base import Optimizer, OptimizerConfig
from curobo.rollout.dynamics_model.integration_utils import build_fd_matrix
from curobo.types.base import TensorDeviceType
from curobo.types.tensor import T_BDOF, T_BHDOF_float, T_BHValue_float, T_BValue_float, T_HDOF_float
class LineSearchType(Enum):
    """Line-search strategies selectable through the optimizer config."""

    # Pick the candidate with the lowest cost.
    GREEDY = "greedy"
    # Sufficient-decrease condition only.
    ARMIJO = "armijo"
    # Sufficient decrease plus curvature condition.
    WOLFE = "wolfe"
    # Curvature condition on absolute values.
    STRONG_WOLFE = "strong_wolfe"
    # Wolfe variant that skips the fallbacks applied to the other Wolfe modes.
    APPROX_WOLFE = "approx_wolfe"
@dataclass
class NewtonOptConfig(OptimizerConfig):
    """Configuration shared by Newton-style optimizers."""

    # Step-size multipliers tried by the line search; its length becomes
    # num_particles in __post_init__.
    line_search_scale: List[int]
    # Threshold compared against the cost in NewtonOptBase.check_convergence.
    cost_convergence: float
    cost_delta_threshold: float
    # When True, the thresholds above/below are overridden in __post_init__.
    fixed_iters: bool
    inner_iters: int  # used for cuda graph
    line_search_type: LineSearchType
    use_cuda_line_search_kernel: bool
    use_cuda_update_best_kernel: bool
    min_iters: int
    step_scale: float
    last_best: float = 0
    use_temporal_smooth: bool = False
    cost_relative_threshold: float = 0.999

    def __post_init__(self):
        """Derive dependent fields, then run the parent post-processing."""
        # One particle per line-search scale.
        self.num_particles = len(self.line_search_scale)
        # Accept either an enum member or its string value.
        self.line_search_type = LineSearchType(self.line_search_type)
        if self.fixed_iters:
            self.cost_delta_threshold, self.cost_relative_threshold = 0.0001, 1.0
        return super().__post_init__()
class NewtonOptBase(Optimizer, NewtonOptConfig):
@profiler.record_function("newton_opt/init")
def __init__(
    self,
    config: Optional[NewtonOptConfig] = None,
):
    """Initialize the Newton optimizer state.

    Args:
        config: optional configuration; when given, its fields are copied onto
            this instance first.
    """
    if config is not None:
        NewtonOptConfig.__init__(self, **vars(config))

    # d_opt and line_scale are needed by update_nenvs, which the base
    # initializer calls, so they are set before Optimizer.__init__.
    self.d_opt = self.horizon * self.d_action
    self.line_scale = self._create_box_line_search(self.line_search_scale)
    Optimizer.__init__(self)

    self.i = -1
    self.outer_iters = math.ceil(self.n_iters / self.inner_iters)

    # Allocate per-environment line-search state, then clear it.
    self.update_nenvs(self.n_envs)
    self.reset()

    # Tile the per-step action bounds over the whole horizon.
    self.action_lows = self.action_lows.repeat(self.horizon)
    self.action_highs = self.action_highs.repeat(self.horizon)
    self.action_range = self.action_highs - self.action_lows
    self.action_step_max = self.step_scale * torch.abs(self.action_range)

    # Constants of the sufficient-decrease and curvature conditions.
    self.c_1 = 1e-5
    self.c_2 = 0.9

    # Output buffers for the CUDA line search; allocated on first use.
    self._out_m_idx = None
    self._out_best_x = None
    self._out_best_c = None
    self._out_best_grad = None
    self.cu_opt_graph = None

    if self.d_opt >= 1024:
        self.use_cuda_line_search_kernel = False

    if self.use_temporal_smooth:
        device, dtype = self.tensor_args.device, self.tensor_args.dtype
        self._temporal_mat = build_fd_matrix(
            self.horizon, order=2, device=device, dtype=dtype
        ).unsqueeze(0)
        identity = torch.eye(self.horizon, device=device, dtype=dtype).unsqueeze(0)
        self._temporal_mat += identity
def reset_cuda_graph(self):
    """Reset the captured optimization graph (if any), then the parent state."""
    graph = self.cu_opt_graph
    if graph is not None:
        graph.reset()
    super().reset_cuda_graph()
@torch.no_grad()
def _get_step_direction(self, cost, q, grad_q):
"""
Reimplement this function in derived class. Gradient Descent is implemented here.
"""
return -1.0 * grad_q.view(-1, self.d_opt)
def _shift(self, shift_steps=1):
# TODO: shift best q?:
self.best_cost[:] = 500000.0
self.best_iteration[:] = 0
self.current_iteration[:] = 0
return True
def _optimize(self, q: T_BHDOF_float, shift_steps=0, n_iters=None):
    """Run the outer optimization loop and return the best iterate.

    Args:
        q: initial iterate, viewable as (n_envs, horizon * d_action).
        shift_steps: forwarded to ``_shift`` and to graph initialization.
        n_iters: accepted for interface compatibility; not read here.

    Returns:
        The best iterate reshaped to (n_envs, horizon, d_action).
    """
    with profiler.record_function("newton_base/shift"):
        self._shift(shift_steps)

    if self.store_debug:
        self.debug.append(q.view(-1, self.horizon, self.d_action).clone())

    with profiler.record_function("newton_base/init_opt"):
        q = q.view(self.n_envs, self.horizon * self.d_action)
        grad_q = q.detach() * 0.0

    # Set up (and optionally capture) the inner-iteration graph once.
    if not self.cu_opt_init:
        self._initialize_opt_iters_graph(q, grad_q, shift_steps=shift_steps)

    for outer in range(self.outer_iters):
        best_q, best_cost, q, grad_q = self._call_opt_iters_graph(q, grad_q)
        may_stop_early = (
            not self.fixed_iters
            and self.use_cuda_update_best_kernel
            and (outer + 1) * self.inner_iters >= self.min_iters
        )
        # Short-circuit keeps the convergence check from running needlessly.
        if may_stop_early and check_convergence(
            self.best_iteration, self.current_iteration, self.last_best
        ):
            break

    return best_q.view(self.n_envs, self.horizon, self.d_action)
def reset(self):
    """Reset iteration counters and best-cost tracking, then the parent.

    NOTE(review): the best cost is reset to 500000.0 here while
    ``update_nenvs`` initializes it to 5000000.0 - confirm which is intended.
    """
    with profiler.record_function("newton/reset"):
        self.i, self._opt_finished = -1, False
        self.best_cost[:] = 500000.0
        self.best_iteration[:] = 0

    super().reset()
def _opt_iters(self, q, grad_q, shift_steps=0):
    """Run ``inner_iters`` optimization steps.

    ``shift_steps`` is accepted for interface compatibility and is not used.

    Returns:
        Detached ``(best_q, best_cost, q, grad_q)`` after the last step.
    """
    cur_q, cur_grad = q.detach(), grad_q.detach()
    for _ in range(self.inner_iters):
        self.i += 1
        cost_n, cur_q, cur_grad = self._opt_step(cur_q.detach(), cur_grad.detach())
        if self.store_debug:
            # Record the running best after every step.
            self.debug.append(self.best_q.view(-1, self.horizon, self.d_action).clone())
            self.debug_cost.append(self.best_cost.detach().view(-1, 1).clone())
    return (
        self.best_q.detach(),
        self.best_cost.detach(),
        cur_q.detach(),
        cur_grad.detach(),
    )
def _opt_step(self, q, grad_q):
with profiler.record_function("newton/line_search"):
q_n, cost_n, grad_q_n = self._approx_line_search(q, grad_q)
with profiler.record_function("newton/step_direction"):
grad_q = self._get_step_direction(cost_n, q_n, grad_q_n)
with profiler.record_function("newton/update_best"):
self._update_best(q_n, grad_q_n, cost_n)
return cost_n, q_n, grad_q
def clip_bounds(self, x):
    """Clamp ``x`` element-wise into ``[action_lows, action_highs]``."""
    return torch.clamp(x, self.action_lows, self.action_highs)
def scale_step_direction(self, dx):
    """Optionally smooth ``dx`` over time, then limit it to the max step.

    With ``use_temporal_smooth`` the step is viewed as (b, horizon, d_action)
    and left-multiplied by the (1, horizon, horizon) temporal matrix before
    being flattened again.
    """
    if self.use_temporal_smooth:
        per_step = dx.view(-1, self.horizon, self.d_action)
        smoothed = self._temporal_mat @ per_step  # 1,h,h x b,h,dof -> b,h,dof
        dx = smoothed.view(-1, self.horizon * self.d_action)
    return scale_action(dx, self.action_step_max)
def project_bounds(self, x):
    """Rescale each row of ``x`` by its largest relative bound violation.

    The per-element violation is ``max(low - x, x - high) / (high - low)``;
    positive values are out of bounds. Rows whose largest violation is not
    positive (or not finite) are returned unchanged; NaNs produced by the
    rescaling are replaced with 0.

    NOTE(review): the row is multiplied by ``1 / scale``, which enlarges it
    when the violation is below 1 - confirm this is the intended projection.
    """
    span = self.action_highs - self.action_lows
    violation = torch.maximum(self.action_lows - x, x - self.action_highs) / span
    # Largest violation across all joint values of a row.
    scale = torch.max(violation, dim=-1, keepdim=True)[0]
    scale = torch.nan_to_num(scale, nan=1.0, posinf=1.0, neginf=1.0)
    scale[scale <= 0.0] = 1.0
    rescaled = (1.0 / scale) * x
    # If we hit NaNs in scaling:
    return torch.nan_to_num(rescaled, nan=0.0)
def _compute_cost_gradient(self, x):
    """Evaluate the rollout cost of ``x`` and its gradient w.r.t. ``x``.

    Returns:
        ``(cost, g_x)`` with cost shaped [n_envs, n_particles, 1] and g_x
        shaped like ``x`` ([n_envs, n_particles, horizon * d_action]).
    """
    x_n = x.detach().requires_grad_(True)
    rollout = self.rollout_fn
    # (batch * line_search_scale) x horizon x d_action
    x_in = x_n.view(self.n_envs * self.num_particles, rollout.horizon, rollout.d_action)
    trajectories = rollout(x_in)
    per_step = trajectories.costs.view(self.n_envs, self.num_particles, rollout.horizon)
    cost = torch.sum(per_step, dim=-1, keepdim=True)
    # l_vec is all ones, so every particle's summed cost is differentiated.
    cost.backward(gradient=self.l_vec, retain_graph=False)
    return cost, x_n.grad.detach()
def _wolfe_line_search(self, x, step_direction):
    """Pick a step size along ``step_direction`` using Wolfe-type conditions.

    Candidates ``x + alpha * step`` (clamped to the action bounds) are
    evaluated for every alpha, then either the CUDA kernel or the PyTorch
    implementation selects one candidate per environment.

    Returns:
        ``(best_x, best_cost, best_grad)`` with ``best_grad`` shaped
        (b, 1, d_opt).
    """
    step_vec = step_direction.detach().unsqueeze(-2)
    candidates = get_x_set_jit(
        step_vec, x, self.alpha_list, self.action_lows, self.action_highs
    )
    x_set = candidates.detach().requires_grad_(True)

    b, h, _ = x_set.shape
    c, g_x = self._compute_cost_gradient(x_set)

    def _stale(buf, rows, cols):
        # True when the cached output buffer is missing or has the wrong size.
        return buf is None or buf.shape[0] * buf.shape[1] != rows * cols

    with torch.no_grad():
        if self.use_cuda_line_search_kernel:
            device, dtype = self.tensor_args.device, self.tensor_args.dtype
            if _stale(self._out_best_x, x_set.shape[0], x_set.shape[2]):
                self._out_best_x = torch.zeros(
                    (x_set.shape[0], x_set.shape[2]), dtype=dtype, device=device
                )
            if _stale(self._out_best_c, c.shape[0], c.shape[2]):
                self._out_best_c = torch.zeros(
                    (c.shape[0], c.shape[2]), dtype=dtype, device=device
                )
            if _stale(self._out_best_grad, g_x.shape[0], g_x.shape[2]):
                self._out_best_grad = torch.zeros(
                    (g_x.shape[0], g_x.shape[2]), dtype=dtype, device=device
                )

            (best_x_n, best_c_n, best_grad_n) = wolfe_line_search(
                self._out_best_x,
                self._out_best_c,
                self._out_best_grad,
                g_x,
                x_set,
                step_vec,
                c,
                self.c_idx,
                self.c_1,
                self.c_2,
                self.zero_alpha_list,
                self.line_search_type == LineSearchType.STRONG_WOLFE,
                self.line_search_type == LineSearchType.APPROX_WOLFE,
            )
            return best_x_n, best_c_n, best_grad_n.view(b, 1, self.d_opt)

        # PyTorch implementation.
        c_0 = c[:, 0:1]
        # Directional derivative at every candidate; index 0 is the reference.
        g_full_step = g_x @ step_vec.transpose(-1, -2)
        g_step = g_full_step[:, 0:1]

        # Condition 1: sufficient decrease.
        wolfe_1 = c <= c_0 + self.c_1 * self.zero_alpha_list * g_step

        # Condition 2: curvature.
        if self.line_search_type == LineSearchType.STRONG_WOLFE:
            wolfe_2 = torch.abs(g_full_step) <= self.c_2 * torch.abs(g_step)
        else:
            wolfe_2 = g_full_step >= self.c_2 * g_step
        wolfe = torch.logical_and(wolfe_1, wolfe_2)

        # Weight each satisfied candidate by (alpha + 0.1) so the argmax picks
        # the largest admissible alpha.
        step_success = wolfe * (self.zero_alpha_list + 0.1)
        _, m_idx = torch.max(step_success, dim=-2)

        if self.line_search_type != LineSearchType.APPROX_WOLFE:
            # Fall back to condition 1 alone where both conditions selected
            # index 0, then force at least index 1.
            step_success_w1 = wolfe_1 * (self.zero_alpha_list + 0.1)
            _, m1_idx = torch.max(step_success_w1, dim=-2)
            m_idx = torch.where(m_idx == 0, m1_idx, m_idx)
            m_idx[m_idx == 0] = 1

        # Convert per-environment indices into rows of the flattened set.
        m = m_idx.squeeze() + self.c_idx
        rows = b * h
        best_c = c.view(rows, -1)[m]
        best_x = x_set.view(rows, -1)[m]
        best_grad = g_x.view(rows, -1)[m].view(b, 1, self.d_opt)
        return best_x.detach(), best_c.detach(), best_grad.detach()
def _greedy_line_search(self, x, step_direction):
    """Select, per environment, the candidate step with the lowest cost.

    Returns:
        ``(best_x, best_cost, best_grad)`` with ``best_grad`` shaped
        (b, 1, d_opt).
    """
    step_vec = step_direction.detach().unsqueeze(-2)
    candidates = self.clip_bounds(x.unsqueeze(-2) + self.alpha_list * step_vec)
    x_set = candidates.detach().requires_grad_(True)

    b, h, _ = x_set.shape
    c, g_x = self._compute_cost_gradient(x_set)

    best_c, m_idx = torch.min(c, dim=-2)
    # Convert per-environment indices into rows of the flattened set.
    m = m_idx.squeeze() + self.c_idx
    rows = b * h
    best_x = x_set.view(rows, -1)[m]
    best_grad = g_x.view(rows, -1)[m].view(b, 1, self.d_opt)
    return best_x, best_c, best_grad
def _armijo_line_search(self, x, step_direction):
    """Select the largest step that satisfies the sufficient-decrease test.

    Index 0 of the candidate set is never returned: a selected index of 0 is
    replaced by 1.

    Returns:
        ``(best_x, best_cost, best_grad)`` with ``best_grad`` shaped
        (b, 1, d_opt).
    """
    step_vec = step_direction.detach().unsqueeze(-2)
    candidates = self.clip_bounds(x.unsqueeze(-2) + self.alpha_list * step_vec)
    x_set = candidates.detach().requires_grad_(True)

    b, h, _ = x_set.shape
    c, g_x = self._compute_cost_gradient(x_set)

    # Reference cost and directional derivative at the first candidate.
    c_0 = c[:, 0:1]
    g_step = g_x[:, 0:1] @ step_vec.transpose(-1, -2)

    # Condition 1: sufficient decrease.
    armijo_ok = c <= c_0 + self.c_1 * self.zero_alpha_list * g_step

    # Weight satisfied candidates by (alpha + 0.1) so the argmax returns the
    # largest admissible alpha.
    step_success = armijo_ok * (self.zero_alpha_list + 0.1)
    _, m_idx = torch.max(step_success, dim=-2)
    m_idx[m_idx == 0] = 1

    m = m_idx.squeeze() + self.c_idx
    rows = b * h
    best_c = c.view(rows, -1)[m]
    best_x = x_set.view(rows, -1)[m]
    best_grad = g_x.view(rows, -1)[m].view(b, 1, self.d_opt)
    return best_x, best_c, best_grad
def _approx_line_search(self, x, step_direction):
    """Scale the step if configured and dispatch to the selected line search.

    The step is rescaled only when ``step_scale`` is neither 0.0 nor 1.0.
    Returns whatever the selected line search returns; implicitly returns
    None if ``line_search_type`` matches none of the handled members.
    """
    if self.step_scale not in (0.0, 1.0):
        step_direction = self.scale_step_direction(step_direction)

    kind = self.line_search_type
    if kind == LineSearchType.GREEDY:
        return self._greedy_line_search(x, step_direction)
    if kind == LineSearchType.ARMIJO:
        return self._armijo_line_search(x, step_direction)
    wolfe_family = [
        LineSearchType.WOLFE,
        LineSearchType.STRONG_WOLFE,
        LineSearchType.APPROX_WOLFE,
    ]
    if kind in wolfe_family:
        return self._wolfe_line_search(x, step_direction)
def check_convergence(self, cost):
    """Return True (and mark the run finished) when no cost exceeds the limit.

    The limit is ``cost_convergence``; ``_opt_finished`` is only written when
    the check succeeds.
    """
    exceeds = cost > self.cost_convergence
    if bool(torch.count_nonzero(exceeds) != 0):
        return False
    self._opt_finished = True
    return True
def _update_best(self, q, grad_q, cost):
if self.use_cuda_update_best_kernel:
(self.best_cost, self.best_q, self.best_iteration) = update_best(
self.best_cost,
self.best_q,
self.best_iteration,
self.current_iteration,
cost,
q,
self.d_opt,
self.last_best,
self.cost_delta_threshold,
self.cost_relative_threshold,
)
else:
cost = cost.detach()
q = q.detach()
mask = cost < self.best_cost
self.best_cost.copy_(torch.where(mask, cost, self.best_cost))
mask = mask.view(mask.shape[0])
mask_q = mask.unsqueeze(-1).expand(-1, self.d_opt)
self.best_q.copy_(torch.where(mask_q, q, self.best_q))
def update_nenvs(self, n_envs):
    """Allocate all per-environment line-search state for ``n_envs``.

    Also clears ``cu_opt_init`` so the optimization graph is rebuilt, then
    delegates to the parent.
    """
    device, dtype = self.tensor_args.device, self.tensor_args.dtype

    # Gradient seed for the backward pass: ones for every particle.
    self.l_vec = torch.ones((n_envs, self.num_particles, 1), device=device, dtype=dtype)

    # Best-so-far tracking.
    self.best_cost = torch.full((n_envs, 1), 5000000.0, device=device, dtype=dtype)
    self.best_q = torch.zeros((n_envs, self.d_opt), device=device, dtype=dtype)
    self.best_grad_q = torch.zeros((n_envs, 1, self.d_opt), device=device, dtype=dtype)

    # Step-size table, repeated for every environment.
    self.alpha_list = self.line_scale.repeat(n_envs, 1, 1)
    self.zero_alpha_list = self.alpha_list[:, :, 0:1].contiguous()
    n_alphas = self.alpha_list.shape[1]
    # Row offset of each environment in the flattened candidate set.
    self.c_idx = torch.arange(
        0, n_envs * n_alphas, step=n_alphas, device=device, dtype=torch.long
    )

    self.best_iteration = torch.zeros(n_envs, device=device, dtype=torch.int16)
    self.current_iteration = torch.zeros(1, device=device, dtype=torch.int16)
    self.cu_opt_init = False
    super().update_nenvs(n_envs)
def _initialize_opt_iters_graph(self, q, grad_q, shift_steps):
if self.use_cuda_graph:
self._create_opt_iters_graph(q, grad_q, shift_steps)
self.cu_opt_init = True
def _create_box_line_search(self, line_search_scale):
"""
Args:
line_search_scale (_type_): should have n values
"""
d = []
dof_vec = torch.zeros(
(self.d_opt), device=self.tensor_args.device, dtype=self.tensor_args.dtype
)
for i in line_search_scale:
d.append(dof_vec + i)
d = torch.stack(d, dim=0).unsqueeze(0)
return d
def _call_opt_iters_graph(self, q, grad_q):
if self.use_cuda_graph:
self._cu_opt_q_in.copy_(q.detach())
self._cu_opt_gq_in.copy_(grad_q.detach())
self.cu_opt_graph.replay()
return (
self._cu_opt_q.clone(),
self._cu_opt_cost.clone(),
self._cu_q.clone(),
self._cu_gq.clone(),
)
else:
return self._opt_iters(q, grad_q)
def _create_opt_iters_graph(self, q, grad_q, shift_steps):
    """Warm up and capture ``_opt_iters`` into a CUDA graph.

    Static input tensors are cloned from ``q`` and ``grad_q``; the iteration
    is run three times on a side stream before being captured once.
    """
    device = self.tensor_args.device
    self._cu_opt_q_in = q.detach().clone()
    self._cu_opt_gq_in = grad_q.detach().clone()

    def run_once():
        # Outputs are stored on self so a replay refreshes them in place.
        self._cu_opt_q, self._cu_opt_cost, self._cu_q, self._cu_gq = self._opt_iters(
            self._cu_opt_q_in, self._cu_opt_gq_in, shift_steps
        )

    # Warm-up on a new stream.
    side_stream = torch.cuda.Stream(device=device)
    side_stream.wait_stream(torch.cuda.current_stream(device=device))
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            run_once()
    torch.cuda.current_stream(device=device).wait_stream(side_stream)

    self.cu_opt_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.cu_opt_graph, stream=side_stream):
        run_once()
@torch.jit.script
def get_x_set_jit(step_vec, x, alpha_list, action_lows, action_highs):
    """Return the clamped candidate set ``x + alpha * step`` for every alpha.

    ``x`` gets a new axis at -2 so it broadcasts against ``alpha_list``.
    """
    candidates = x.unsqueeze(-2) + alpha_list * step_vec
    return torch.clamp(candidates, action_lows, action_highs)
@torch.jit.script
def _armijo_line_search_tail_jit(c, g_x, step_direction, c_1, alpha_list, c_idx, x_set, d_opt):
c_0 = c[:, 0:1]
g_0 = g_x[:, 0:1]
step_vec = step_direction.unsqueeze(-2)
step_vec_T = step_vec.transpose(-1, -2)
g_step = g_0 @ step_vec_T
# condition 1:
armjio_1 = c <= c_0 + c_1 * alpha_list * g_step # dot product
# get the last occurence of true (this will be the largest admissable alpha value):
# wolfe will have 1 for indices that satisfy.
# find the
step_success = armjio_1 * (alpha_list[:, :, 0:1] + 0.1)
_, m_idx = torch.max(step_success, dim=-2)
m_idx[m_idx == 0] = 1
m = m_idx.squeeze() + c_idx
b, h, _ = x_set.shape
g_x = g_x.view(b * h, -1)
xs = x_set.view(b * h, -1)
cs = c.view(b * h, -1)
best_c = cs[m]
best_x = xs[m]
best_grad = g_x[m].view(b, 1, d_opt)
return (best_x, best_c, best_grad)
@torch.jit.script
def _wolfe_search_tail_jit(c, g_x, x_set, m, d_opt: int):
b, h, _ = x_set.shape
g_x = g_x.view(b * h, -1)
xs = x_set.view(b * h, -1)
cs = c.view(b * h, -1)
best_c = cs[m]
best_x = xs[m]
best_grad = g_x[m].view(b, 1, d_opt)
return (best_x, best_c, best_grad)
@torch.jit.script
def scale_action(dx, action_step_max):
    """Shrink each row of ``dx`` so no element exceeds ``action_step_max``.

    The divisor is the largest ``|dx| / action_step_max`` of the row, clamped
    to at least 1, so rows already within the limit are returned unchanged.
    """
    ratio = torch.abs(dx) / action_step_max
    worst = torch.max(ratio, dim=-1, keepdim=True)[0]
    divisor = torch.clamp(worst, 1.0)
    return dx / divisor
@torch.jit.script
def check_convergence(
    best_iteration: torch.Tensor, current_iteration: torch.Tensor, last_best: int
) -> bool:
    """Return True when the largest ``best_iteration`` entry is <= -last_best.

    ``current_iteration`` is accepted for interface compatibility and is not
    read.
    """
    threshold = -1.0 * (last_best)
    newest_best = torch.max(best_iteration).item()
    if newest_best <= threshold:
        return True
    return False

202
src/curobo/opt/opt_base.py Normal file
View File

@@ -0,0 +1,202 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
import time
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
# Third Party
import torch
import torch.autograd.profiler as profiler
# CuRobo
from curobo.rollout.rollout_base import Goal, RolloutBase
from curobo.types.base import TensorDeviceType
from curobo.util.logger import log_info
from curobo.util.torch_utils import is_cuda_graph_available
@dataclass
class OptimizerConfig:
    """Configuration shared by all optimizers."""

    d_action: int
    # Converted to device tensors in __post_init__.
    action_lows: List[float]
    action_highs: List[float]
    horizon: int
    n_iters: int
    rollout_fn: RolloutBase
    tensor_args: TensorDeviceType
    # Downgraded in __post_init__ if CUDA graphs are unavailable.
    use_cuda_graph: bool
    store_debug: bool
    debug_info: Any
    # Iteration count substituted on the first optimize() call.
    cold_start_n_iters: int
    # None is replaced by 1 in __post_init__.
    num_particles: Union[int, None]
    n_envs: int
    # Synchronize CUDA before measuring the optimization time.
    sync_cuda_time: bool
    use_coo_sparse: bool

    def __post_init__(self):
        """Move the action bounds to the device and normalize optional fields."""
        for bound in ("action_highs", "action_lows"):
            object.__setattr__(self, bound, self.tensor_args.to_device(getattr(self, bound)))
        # Check that the installed runtime supports CUDA graphs.
        if self.use_cuda_graph:
            self.use_cuda_graph = is_cuda_graph_available()
        if self.num_particles is None:
            self.num_particles = 1

    @staticmethod
    def create_data_dict(
        data_dict: Dict,
        rollout_fn: RolloutBase,
        tensor_args: TensorDeviceType = TensorDeviceType(),
        child_dict: Optional[Dict] = None,
    ):
        """Fill a config dictionary with values taken from ``rollout_fn``.

        Args:
            data_dict: base dictionary; deep-copied when ``child_dict`` is None.
            rollout_fn: source of ``d_action``, action bounds and ``horizon``.
            tensor_args: device/dtype descriptor stored in the result.
            child_dict: dictionary to update in place instead of a copy.

        Returns:
            The updated dictionary, with ``num_particles`` defaulting to 1.
        """
        if child_dict is None:
            child_dict = deepcopy(data_dict)
        child_dict.update(
            d_action=rollout_fn.d_action,
            action_lows=rollout_fn.action_bound_lows,
            action_highs=rollout_fn.action_bound_highs,
            rollout_fn=rollout_fn,
            tensor_args=tensor_args,
            horizon=rollout_fn.horizon,
        )
        child_dict.setdefault("num_particles", 1)
        return child_dict
class Optimizer(OptimizerConfig):
def __init__(self, config: Optional[OptimizerConfig] = None) -> None:
    """Initialize common optimizer state.

    Args:
        config: optional configuration; when given, its fields are copied onto
            this instance through the dataclass initializer.
    """
    if config is not None:
        super().__init__(**vars(config))
    # Duration of the last optimize() call.
    self.opt_dt = 0.0
    # True until the first optimize() call.
    self.COLD_START = True
    self.update_nenvs(self.n_envs)
    self._batch_goal = None
    self._rollout_list = None
    self.debug, self.debug_cost = [], []
def optimize(self, opt_tensor: torch.Tensor, shift_steps=0, n_iters=None) -> torch.Tensor:
    """Run ``_optimize`` and record its wall-clock duration in ``opt_dt``.

    On the first call ``n_iters`` is replaced by ``cold_start_n_iters``. When
    ``sync_cuda_time`` is set, CUDA is synchronized before the end time is
    read.
    """
    if self.COLD_START:
        n_iters, self.COLD_START = self.cold_start_n_iters, False
    started = time.time()
    result = self._optimize(opt_tensor, shift_steps, n_iters)
    if self.sync_cuda_time:
        torch.cuda.synchronize()
    self.opt_dt = time.time() - started
    return result
@abstractmethod
def _optimize(self, opt_tensor: torch.Tensor, shift_steps=0, n_iters=None) -> torch.Tensor:
pass
def _shift(self, shift_steps=0):
"""
Shift the variables in the solver to hotstart the next timestep
"""
return
def update_params(self, goal: Goal):
    """Store ``goal`` as the batched goal and pass it to the rollout.

    The first call creates the batched goal with
    ``goal.repeat_seeds(num_particles)``; later calls copy into it in place.
    """
    with profiler.record_function("OptBase/batch_goal"):
        if self._batch_goal is None:
            self._batch_goal = goal.repeat_seeds(self.num_particles)
        else:
            self._batch_goal.copy_(goal, update_idx_buffers=True)  # why True?
    self.rollout_fn.update_params(self._batch_goal)
def reset(self):
    """Reset the rollout function and drop the collected debug data.

    ``COLD_START`` is intentionally left untouched.
    """
    self.rollout_fn.reset()
    self.debug, self.debug_cost = [], []
def update_nenvs(self, n_envs):
    """Rebuild the environment kernel for ``n_envs`` and store the count.

    ``n_envs`` must be positive (checked with an assertion).
    """
    assert n_envs > 0
    particles = self.num_particles
    self._update_env_kernel(n_envs, particles)
    self.n_envs = n_envs
def _update_env_kernel(self, n_envs, num_particles):
    """Build the index tensors and the environment-to-particle kernel.

    ``env_kernel_`` has a one at ``[i * num_particles + x, i]`` for every
    environment ``i`` and particle ``x``. With ``use_coo_sparse`` it stays a
    sparse COO tensor on the target device; otherwise it is built on the CPU,
    densified and then moved to the device.
    """
    log_info(f"Updating env kernel [n_envs: {n_envs!s} , num_particles: {num_particles!s} ]")
    device, dtype = self.tensor_args.device, self.tensor_args.dtype

    self.env_col = torch.arange(0, n_envs, step=1, dtype=torch.long, device=device)
    self.n_select_ = torch.tensor(
        [x * n_envs + x for x in range(n_envs)],
        device=device,
        dtype=torch.long,
    )

    # (row, column) of every non-zero entry of the kernel.
    pairs = [
        [env * num_particles + particle, env]
        for env in range(n_envs)
        for particle in range(num_particles)
    ]
    sparse_values = torch.ones(len(pairs))
    sparse_indices = torch.tensor(pairs)

    if self.use_coo_sparse:
        self.env_kernel_ = torch.sparse_coo_tensor(
            sparse_indices.t(), sparse_values, device=device, dtype=dtype
        )
    else:
        cpu_kernel = torch.sparse_coo_tensor(
            sparse_indices.t(), sparse_values, device="cpu", dtype=dtype
        )
        self.env_kernel_ = cpu_kernel
        self.env_kernel_ = self.env_kernel_.to_dense().to(device=device)
    self._env_seeds = self.num_particles
def get_nenv_tensor(self, x):
    """Expand a per-environment tensor to one row per particle.

    Multiplies the environment kernel with ``x``, so an input of shape
    (n_env, ...) yields one row for every (environment, particle) pair.
    """
    return self.env_kernel_ @ x
def reset_seed(self):
    """Hook for resetting random seeds; the base class only returns True."""
    done = True
    return done
def reset_cuda_graph(self):
    """Invalidate the optimizer graph state and reset the rollout's graph.

    Clears ``cu_opt_init`` when CUDA graphs are enabled (otherwise only logs),
    drops the batched goal and forwards the reset to the rollout function.
    """
    if not self.use_cuda_graph:
        log_info("Cuda Graph was not enabled")
    else:
        self.cu_opt_init = False
    self._batch_goal = None
    self.rollout_fn.reset_cuda_graph()
def get_all_rollout_instances(self) -> List[RolloutBase]:
    """Return the cached list of rollout instances, creating it on first use.

    The list contains only ``rollout_fn``.
    """
    cached = self._rollout_list
    if cached is None:
        cached = [self.rollout_fn]
        self._rollout_list = cached
    return self._rollout_list

View File

@@ -0,0 +1,14 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
"""
This module contains particle-based optimization solvers.
"""

View File

@@ -0,0 +1,76 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from dataclasses import dataclass
from typing import Optional
# Third Party
import torch
# CuRobo
from curobo.opt.particle.parallel_mppi import CovType, ParallelMPPI, ParallelMPPIConfig
@dataclass
class ParallelESConfig(ParallelMPPIConfig):
    """Configuration for the evolution-strategies optimizer."""

    # Step size applied to the weighted update of the mean action.
    learning_rate: float = 0.1
class ParallelES(ParallelMPPI, ParallelESConfig):
    """MPPI variant whose mean update and weights follow evolution strategies."""

    def __init__(self, config: Optional[ParallelESConfig] = None):
        """Copy ``config`` onto the instance (if given) and initialize MPPI."""
        if config is not None:
            ParallelESConfig.__init__(self, **vars(config))
        ParallelMPPI.__init__(self)

    def _compute_mean(self, w, actions):
        """Return the evolution-strategies mean for weights ``w``.

        Raises:
            NotImplementedError: if ``cov_type`` is neither ``SIGMA_I`` nor
                ``DIAG_A``.
        """
        supported = [CovType.SIGMA_I, CovType.DIAG_A]
        if self.cov_type not in supported:
            raise NotImplementedError()
        return compute_es_mean(
            w,
            actions,
            self.mean_action,
            self.full_inv_cov,
            self.num_particles,
            self.learning_rate,
        )

    def _exp_util(self, total_costs):
        """Convert total costs into standardized weights."""
        return calc_exp(total_costs)
@torch.jit.script
def calc_exp(
    total_costs,
):
    """Standardize negated costs along the last dimension.

    Costs are negated (so lower cost gives a larger weight), then shifted by
    their mean and divided by their standard deviation over ``dim=-1``.
    """
    utility = -1.0 * total_costs
    centered = utility - torch.mean(utility, dim=-1, keepdim=True)
    spread = torch.std(utility, dim=-1, keepdim=True)
    return centered / spread
@torch.jit.script
def compute_es_mean(
    w, actions, mean_action, full_inv_cov, num_particles: int, learning_rate: float
):
    """Move the mean action by a learning-rate-scaled, weight-averaged step.

    The action deviations from the current mean are divided by the (biased)
    standard deviation of ``w`` over dims 1-3, weighted by ``w``, summed over
    dim ``-3``, multiplied by ``full_inv_cov / num_particles`` and finally
    added to ``mean_action`` scaled by ``learning_rate``.
    """
    weight_std = torch.std(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
    normalized_delta = (actions - mean_action.unsqueeze(1)) / weight_std
    weighted_sum = torch.sum(w * normalized_delta, dim=-3, keepdim=True)
    step = torch.matmul(weighted_sum, full_inv_cov / num_particles).squeeze(1)
    return mean_action + learning_rate * step

View File

@@ -0,0 +1,615 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional
# Third Party
import torch
import torch.autograd.profiler as profiler
# CuRobo
from curobo.opt.particle.particle_opt_base import ParticleOptBase, ParticleOptConfig, SampleMode
from curobo.opt.particle.particle_opt_utils import (
SquashType,
cost_to_go,
gaussian_entropy,
matrix_cholesky,
scale_ctrl,
)
from curobo.rollout.rollout_base import RolloutBase, Trajectory
from curobo.types.base import TensorDeviceType
from curobo.types.robot import State
from curobo.util.logger import log_info
from curobo.util.sample_lib import HaltonSampleLib, SampleConfig, SampleLib
from curobo.util.tensor_util import copy_tensor
class BaseActionType(Enum):
    """Allowed values for the ``base_action`` field of :class:`ParallelMPPIConfig`."""

    # NOTE(review): only the names are declared here; how each option is
    # applied is not visible in this module section.
    REPEAT = 0
    NULL = 1
    RANDOM = 2
class CovType(Enum):
    """Covariance parameterizations selectable through ``ParallelMPPIConfig.cov_type``."""

    # Single variance multiplied by an identity matrix (see ``ParallelMPPI.full_cov``).
    SIGMA_I = 0
    # Diagonal covariance with one entry per action dimension.
    DIAG_A = 1
    # Presumably a full d_action x d_action covariance; its update raises
    # NotImplementedError in ``ParallelMPPI._compute_covariance``.
    FULL_A = 2
    # Covariance over the flattened (horizon * d_action) space.
    FULL_HA = 3
@dataclass
class ParallelMPPIConfig(ParticleOptConfig):
    """Configuration fields for :class:`ParallelMPPI`.

    ``init_mean`` and ``init_cov`` are converted to device tensors in
    ``__post_init__``; ``create_data_dict`` turns a plain dictionary into the
    keyword arguments expected by this dataclass.
    """

    init_mean: float
    init_cov: float
    base_action: BaseActionType
    step_size_mean: float
    step_size_cov: float
    null_act_frac: float
    squash_fn: SquashType
    cov_type: CovType
    sample_params: SampleConfig
    update_cov: bool
    random_mean: bool
    beta: float
    alpha: float
    gamma: float
    kappa: float
    sample_per_env: bool

    def __post_init__(self):
        """Move the initial covariance/mean to the configured device, then run the parent hook."""
        cov_on_device = self.tensor_args.to_device(self.init_cov)
        self.init_cov = cov_on_device.unsqueeze(0)
        mean_on_device = self.tensor_args.to_device(self.init_mean)
        self.init_mean = mean_on_device.clone()
        return super().__post_init__()

    @staticmethod
    @profiler.record_function("parallel_mppi_config/create_data_dict")
    def create_data_dict(
        data_dict: Dict,
        rollout_fn: RolloutBase,
        tensor_args: TensorDeviceType = TensorDeviceType(),
        child_dict: Optional[Dict] = None,
    ):
        """Build the constructor keyword dictionary from a raw configuration dictionary.

        Args:
            data_dict: Raw configuration values.
            rollout_fn: Rollout instance; supplies ``d_action`` and, when no
                ``init_mean`` is configured, the initial action sequence.
            tensor_args: Device/dtype settings stored in the sample parameters.
            child_dict: Partially processed dictionary; a deep copy of
                ``data_dict`` is used when omitted.

        Returns:
            Dict: ``child_dict`` with enum names resolved to enum members,
            ``sample_params`` replaced by a ``SampleConfig`` and ``init_mean`` filled in.
        """
        if child_dict is None:
            child_dict = deepcopy(data_dict)
        child_dict = ParticleOptConfig.create_data_dict(
            data_dict, rollout_fn, tensor_args, child_dict
        )

        # Resolve enum members from their string names.
        enum_fields = (
            ("base_action", BaseActionType),
            ("squash_fn", SquashType),
            ("cov_type", CovType),
        )
        for key, enum_cls in enum_fields:
            child_dict[key] = enum_cls[child_dict[key]]

        # Complete the sampler settings before building the SampleConfig.
        sample_kwargs = child_dict["sample_params"]
        sample_kwargs["d_action"] = rollout_fn.d_action
        sample_kwargs["horizon"] = child_dict["horizon"]
        sample_kwargs["tensor_args"] = tensor_args
        child_dict["sample_params"] = SampleConfig(**sample_kwargs)

        if "init_mean" not in child_dict:
            child_dict["init_mean"] = rollout_fn.get_init_action_seq()
        return child_dict
class ParallelMPPI(ParticleOptBase, ParallelMPPIConfig):
@profiler.record_function("parallel_mppi/init")
def __init__(self, config: Optional[ParallelMPPIConfig] = None):
    """Set up samplers, buffers and the initial control distribution.

    Args:
        config: Optional configuration; when given, its fields are copied onto
            this instance before the base optimizer is initialized.
    """
    if config is not None:
        ParallelMPPIConfig.__init__(self, **vars(config))
    ParticleOptBase.__init__(self)

    self.sample_lib = SampleLib(self.sample_params)
    self._sample_set = None
    self._sample_iter = None

    # Identity matrix sized for the covariance type: (H*A)x(H*A) for FULL_HA, AxA otherwise.
    if self.cov_type == CovType.FULL_HA:
        eye_dim = self.horizon * self.d_action
    else:
        eye_dim = self.d_action
    self.I = torch.eye(eye_dim, device=self.tensor_args.device, dtype=self.tensor_args.dtype)

    self.Z_seq = torch.zeros(
        1,
        self.horizon,
        self.d_action,
        device=self.tensor_args.device,
        dtype=self.tensor_args.dtype,
    )

    # Distribution state is filled in by reset_distribution() below.
    for attr in (
        "delta",
        "mean_action",
        "act_seq",
        "cov_action",
        "best_traj",
        "scale_tril",
        "visual_traj",
    ):
        setattr(self, attr, None)
    if self.debug_info is not None and "visual_traj" in self.debug_info:
        self.visual_traj = self.debug_info["visual_traj"]

    self.top_values = None
    self.top_idx = None
    self.top_trajs = None

    # Separate sampler used by reset_mean() when random_mean is enabled.
    self.mean_lib = HaltonSampleLib(
        SampleConfig(
            self.horizon,
            self.d_action,
            tensor_args=self.tensor_args,
            fixed_samples=False,
            seed=2567,
            filter_coeffs=None,
        )
    )

    self.reset_distribution()
    self.update_samples()

    self._use_cuda_graph = False
    self._init_cuda_graph = False
    self.info = {"rollout_time": 0.0, "entropy": []}
    self._batch_size = -1
    self._store_debug = False
def get_rollouts(self):
    """Return the rollouts cached in ``self.top_trajs`` (``None`` until some are stored)."""
    cached_rollouts = self.top_trajs
    return cached_rollouts
def reset_distribution(self):
    """Reset the control distribution: first the mean, then the covariance."""
    for reset_step in (self.reset_mean, self.reset_covariance):
        reset_step()
def _compute_total_cost(self, costs):
    """Reduce per-step costs to one discounted total per trajectory.

    Delegates to ``jit_compute_total_cost`` using ``self.gamma_seq`` as the
    per-step discount factors.
    """
    return jit_compute_total_cost(self.gamma_seq, costs)
def _exp_util(self, total_costs):
    """Convert total costs into softmax weights with temperature ``self.beta``."""
    return jit_calculate_exp_util(self.beta, total_costs)
def _compute_mean(self, w, actions):
# get the new means from here
new_mean = torch.sum(w * actions, dim=-3)
return new_mean
def _compute_covariance(self, w, actions):
    """Compute the covariance update for the configured covariance type.

    Args:
        w: Particle weights, broadcastable against ``actions``.
        actions: Sampled action sequences.

    Returns:
        The covariance update tensor, or ``None`` when ``self.update_cov`` is disabled.

    Raises:
        NotImplementedError: For ``CovType.FULL_A``.
        ValueError: For an unknown covariance type.
    """
    if not self.update_cov:
        return None

    if self.cov_type == CovType.SIGMA_I:
        delta_actions = actions - self.mean_action.unsqueeze(-3)
        weighted_delta = w * (delta_actions**2)
        # Sum over the two trailing dims, then average over the remaining last dim.
        cov_update = torch.mean(
            torch.sum(torch.sum(weighted_delta, dim=-2), dim=-1), dim=-1, keepdim=True
        )
    elif self.cov_type == CovType.DIAG_A:
        # Diagonal covariance of size AxA.
        cov_update = jit_diag_a_cov_update(w, actions, self.mean_action)
    elif self.cov_type == CovType.FULL_A:
        # Fail before doing any tensor work: the update was never implemented.
        raise NotImplementedError("Covariance update for CovType.FULL_A is not implemented")
    elif self.cov_type == CovType.FULL_HA:
        delta_actions = actions - self.mean_action.unsqueeze(-3)
        # NOTE(review): only the first entry along dim 0 is used here.
        delta = delta_actions[0, ...]
        weighted_delta = (
            torch.sqrt(w) * delta.view(delta.shape[0], delta.shape[1] * delta.shape[2]).T
        )
        cov_update = torch.matmul(weighted_delta, weighted_delta.T)
    else:
        raise ValueError("Unidentified covariance type in update_distribution")
    return cov_update
def _update_cov_scale(self):
    """Refresh ``self.scale_tril`` from ``self.cov_action`` (no-op unless ``update_cov``).

    Raises:
        NotImplementedError: For ``CovType.FULL_HA``.
    """
    if not self.update_cov:
        return
    kind = self.cov_type
    if kind == CovType.FULL_HA:
        raise NotImplementedError
    if kind == CovType.SIGMA_I:
        self.scale_tril = torch.sqrt(self.cov_action)
    elif kind == CovType.DIAG_A:
        # In-place copy keeps the existing scale_tril buffer.
        self.scale_tril.copy_(torch.sqrt(self.cov_action))
    elif kind == CovType.FULL_A:
        self.scale_tril = matrix_cholesky(self.cov_action)
@torch.no_grad()
def _update_distribution(self, trajectories: Trajectory):
    """Update the mean (and optionally covariance) from weighted rollouts.

    Args:
        trajectories: Rollout result providing ``costs`` and ``actions``; its
            ``state`` is only read when rollouts are stored for visualization.
    """
    costs = trajectories.costs
    actions = trajectories.actions

    total_costs = self._compute_total_cost(costs)
    w = self._exp_util(total_costs)

    # Keep the highest-weight action sequence when running in BEST mode.
    if self.sample_mode == SampleMode.BEST:
        best_idx = torch.argmax(w, dim=-1)
        self.best_traj.copy_(actions[self.env_col, best_idx])

    if self.store_rollouts and self.visual_traj is not None:
        vis_seq = getattr(trajectories.state, self.visual_traj)
        # Never request more entries than there are particles per env.
        top_k = min(20, total_costs.shape[1])
        # NOTE(review): torch.topk returns the *largest* costs by default — confirm intent.
        top_values, top_idx = torch.topk(total_costs, top_k, dim=1)
        self.top_values = top_values
        self.top_idx = top_idx
        # vis_seq is flat over (env, particle); offset each env's indices accordingly.
        per_env_trajs = [
            torch.index_select(vis_seq, 0, top_idx[i] + (self.particles_per_env * i))
            for i in range(top_idx.shape[0])
        ]
        top_trajs = torch.cat(per_env_trajs, dim=0)
        # Compare shapes (not a shape against the tensor itself) before reusing the buffer.
        if self.top_trajs is None or top_trajs.shape != self.top_trajs.shape:
            self.top_trajs = top_trajs
        else:
            self.top_trajs.copy_(top_trajs)

    w = w.unsqueeze(-1).unsqueeze(-1)
    new_mean = self._compute_mean(w, actions)

    # The covariance update uses the mean from before this iteration's blend.
    if self.update_cov:
        cov_update = self._compute_covariance(w, actions)
        new_cov = jit_blend_cov(self.cov_action, cov_update, self.step_size_cov, self.kappa)
        self.cov_action.copy_(new_cov)
        self._update_cov_scale()

    new_mean = jit_blend_mean(self.mean_action, new_mean, self.step_size_mean)
    self.mean_action.copy_(new_mean)
@torch.no_grad()
def sample_actions(self, init_act):
    """Build the batch of candidate action sequences for the next rollout.

    Takes the current slice of the pre-generated sample set, scales it by
    ``full_scale_tril`` and adds it to the mean. Negated-mean and zero action
    sequences are appended when configured, then the batch is flattened to
    ``(total_num_particles, horizon, d_action)`` and squashed to the action limits.

    Args:
        init_act: Unused by this implementation.
    """
    noise = torch.index_select(self._sample_set, 0, self._sample_iter).squeeze(0)

    # Advance to the next pre-generated sample slice, wrapping after n_iters.
    if not self.sample_params.fixed_samples:
        self._sample_iter[:] += 1
        self._sample_iter_n += 1
        if self._sample_iter_n >= self.n_iters:
            self._sample_iter_n = 0
            self._sample_iter[:] = 0
            log_info(
                "Resetting sample iterations in particle opt base to 0, this is okay during graph capture"
            )

    candidates = [self.mean_action.unsqueeze(-3) + noise * self.full_scale_tril]

    if self.neg_per_env > 0:
        mirrored_mean = (-1.0 * self.mean_action).unsqueeze(-3)
        candidates.append(mirrored_mean.expand(-1, self.neg_per_env, -1, -1))

    if self.null_per_env > 0:
        zero_actions = self.null_act_seqs[: self.null_per_env].unsqueeze(0)
        candidates.append(zero_actions.expand(self.n_envs, -1, -1, -1))

    stacked = torch.cat(candidates, dim=-3)
    flat = stacked.reshape(self.total_num_particles, self.horizon, self.d_action)
    return scale_ctrl(flat, self.action_lows, self.action_highs, squash_fn=self.squash_fn)
def update_seed(self, init_act):
    """Use ``init_act`` as the new initial mean (forwards to ``update_init_mean``)."""
    self.update_init_mean(init_act)
def update_init_mean(self, init_mean):
    """Overwrite the mean action and the best trajectory with ``init_mean``.

    ``init_mean`` is expanded along dim 0 to ``n_envs`` when its leading size
    differs. Each buffer is updated through ``copy_tensor``; when that returns
    a falsy value the attribute is rebound to a fresh clone instead.
    """
    seed_mean = init_mean
    if seed_mean.shape[0] != self.n_envs:
        seed_mean = seed_mean.expand(self.n_envs, -1, -1)
    for attr in ("mean_action", "best_traj"):
        if not copy_tensor(seed_mean, getattr(self, attr)):
            setattr(self, attr, seed_mean.clone())
def reset_mean(self):
    """Reset the mean to ``init_mean``, or to fresh samples when ``random_mean`` is set."""
    with profiler.record_function("mppi/reset_mean"):
        if self.random_mean:
            new_mean = self.mean_lib.get_samples([self.n_envs])
        else:
            new_mean = self.init_mean
        self.update_init_mean(new_mean)
def reset_covariance(self):
    """Reset covariance, inverse covariance and scale factor from ``init_cov``.

    Supports ``CovType.SIGMA_I`` and ``CovType.DIAG_A``.

    Raises:
        ValueError: For any other covariance type.
    """
    with profiler.record_function("mppi/reset_cov"):
        if self.cov_type == CovType.SIGMA_I:
            # init_cov can either be a single value, or n_envs x 1
            if self.init_cov.shape[0] != self.n_envs:
                self.cov_action = self.init_cov.unsqueeze(0).expand(self.n_envs, -1)
            else:
                self.cov_action = self.init_cov
        elif self.cov_type == CovType.DIAG_A:
            # init_cov can either be a single value, or n_envs x 1 or n_envs x d_action
            diag_cov = self.init_cov.clone()
            if diag_cov.dim() == 1:
                diag_cov = diag_cov.unsqueeze(-1).expand(-1, self.d_action)
            if diag_cov.dim() == 2 and diag_cov.shape[-1] != self.d_action:
                diag_cov = diag_cov.expand(-1, self.d_action)
            diag_cov = diag_cov.unsqueeze(1)
            if diag_cov.shape[0] != self.n_envs:
                diag_cov = diag_cov.expand(self.n_envs, -1, -1)
            if not copy_tensor(diag_cov.clone(), self.cov_action):
                self.cov_action = diag_cov.clone()
        else:
            raise ValueError("Unidentified covariance type in update_distribution")

        # Derived quantities are computed the same way for both supported types.
        self.inv_cov_action = 1.0 / self.cov_action
        scale = torch.sqrt(self.cov_action)
        if not copy_tensor(scale, self.scale_tril):
            self.scale_tril = scale
def _get_action_seq(self, mode: SampleMode):
    """Return the action sequence selected by ``mode``.

    MEAN returns the current mean, BEST the stored best trajectory and SAMPLE
    the mean plus freshly generated noise multiplied by ``full_scale_tril``.
    The result is not squashed to the action limits here.

    Raises:
        ValueError: For an unknown mode.
    """
    if mode == SampleMode.MEAN:
        return self.mean_action
    if mode == SampleMode.BEST:
        return self.best_traj
    if mode == SampleMode.SAMPLE:
        noise = self.generate_noise(
            shape=torch.Size((1, self.horizon)), base_seed=self.seed + 123 * self.num_steps
        )
        return self.mean_action + torch.matmul(noise, self.full_scale_tril)
    raise ValueError("Unidentified sampling mode in get_next_action")
def generate_noise(self, shape, base_seed=None):
    """Draw noise samples of the requested shape from ``self.sample_lib``.

    Args:
        shape: Sample shape forwarded as ``sample_shape``.
        base_seed: Seed forwarded to the sampler. It is passed as
            ``base_seed=`` — the keyword ``update_samples`` uses for the same
            ``get_samples`` call — rather than ``seed=``.

    Returns:
        The samples returned by ``self.sample_lib.get_samples``.
    """
    return self.sample_lib.get_samples(sample_shape=shape, base_seed=base_seed)
@property
def full_scale_tril(self):
    """Scale factor used to turn unit noise into action perturbations.

    For SIGMA_I and DIAG_A the stored ``scale_tril`` is unsqueezed and expanded
    so that dim ``-2`` has size ``horizon``; for FULL_A and FULL_HA it is
    returned unchanged. Any other covariance type yields ``None``.

    NOTE(review): a second ``full_scale_tril`` property with the same logic is
    defined further down in this class and is the one that takes effect.
    """
    kind = self.cov_type
    if kind == CovType.SIGMA_I:
        return self.scale_tril.unsqueeze(-2).unsqueeze(-2).expand(-1, -1, self.horizon, -1)
    if kind == CovType.DIAG_A:
        return self.scale_tril.unsqueeze(-2).expand(-1, -1, self.horizon, -1)
    if kind == CovType.FULL_A or kind == CovType.FULL_HA:
        return self.scale_tril
    return None
def _calc_val(self, trajectories: Trajectory):
    """Estimate the value of the current state from rollouts (soft-min of total costs).

    Args:
        trajectories: Rollout result providing ``costs`` and ``actions``.

    Returns:
        ``-beta * logsumexp(-total_costs / beta)`` reduced over dim 0.

    NOTE(review): ``self._control_costs`` is not defined in the visible part of
    this class — confirm it exists before enabling ``calculate_value``.
    """
    costs = trajectories.costs
    actions = trajectories.actions
    delta = actions - self.mean_action.unsqueeze(0)

    # Discounted cost-to-go at the first time step of every trajectory.
    traj_costs = cost_to_go(costs, self.gamma_seq)[:, 0]
    control_costs = self._control_costs(delta)
    total_costs = traj_costs + self.beta * control_costs

    # torch.logsumexp requires an explicit ``dim``; after the ``[:, 0]`` index the
    # leading axis is the one enumerating trajectories, so reduce over it.
    val = -self.beta * torch.logsumexp((-1.0 / self.beta) * total_costs, dim=0)
    return val
def reset(self):
    """Restore the initial distribution and sample set, then reset the base optimizer."""
    self.reset_distribution()
    # Rewind the sample iterator (counter and index tensor).
    self._sample_iter_n = 0
    self._sample_iter[:] = 0
    # Regenerating samples helps when restarting optimization.
    self.update_samples()
    super().reset()
@property
def squashed_mean(self):
    """Mean action passed through ``scale_ctrl`` with the configured limits and squash function."""
    lows = self.action_lows
    highs = self.action_highs
    return scale_ctrl(self.mean_action, lows, highs, squash_fn=self.squash_fn)
@property
def full_cov(self):
    """Covariance expanded to matrix form for the configured covariance type.

    Returns:
        ``cov_action * I`` for SIGMA_I, a (batched) diagonal matrix for DIAG_A,
        and ``cov_action`` itself for FULL_A / FULL_HA. ``None`` for any other type.
    """
    if self.cov_type == CovType.SIGMA_I:
        return self.cov_action * self.I
    elif self.cov_type == CovType.DIAG_A:
        # cov_action carries leading batch dims for DIAG_A (see reset_covariance),
        # which torch.diag rejects; diag_embed builds the batched diagonal and
        # matches what full_inv_cov does for the inverse.
        return torch.diag_embed(self.cov_action)
    elif self.cov_type == CovType.FULL_A:
        return self.cov_action
    elif self.cov_type == CovType.FULL_HA:
        return self.cov_action
@property
def full_inv_cov(self):
    """Inverse covariance expanded to matrix form for the configured covariance type.

    Returns ``None`` when the covariance type is not one of the known members.
    """
    kind = self.cov_type
    if kind == CovType.SIGMA_I:
        return self.inv_cov_action * self.I
    if kind == CovType.DIAG_A:
        return torch.diag_embed(self.inv_cov_action)
    if kind == CovType.FULL_A or kind == CovType.FULL_HA:
        return self.inv_cov_action
    return None
@property
def full_scale_tril(self):
    """Scale factor used to turn unit noise into action perturbations.

    SIGMA_I and DIAG_A expand the stored ``scale_tril`` so that dim ``-2`` has
    size ``horizon``; FULL_A and FULL_HA return it unchanged. Any other
    covariance type yields ``None``.

    NOTE(review): this repeats an earlier ``full_scale_tril`` property in the
    same class; being defined later, this is the definition that is used.
    """
    kind = self.cov_type
    if kind == CovType.SIGMA_I:
        per_env = self.scale_tril.unsqueeze(-2).unsqueeze(-2)
        return per_env.expand(-1, -1, self.horizon, -1)
    if kind == CovType.DIAG_A:
        per_env = self.scale_tril.unsqueeze(-2)
        return per_env.expand(-1, -1, self.horizon, -1)
    if kind == CovType.FULL_A or kind == CovType.FULL_HA:
        return self.scale_tril
    return None
@property
def entropy(self):
    """Gaussian entropy computed by ``gaussian_entropy`` from ``full_scale_tril``."""
    return gaussian_entropy(L=self.full_scale_tril)
def reset_seed(self):
    """Recreate both samplers, regenerate the sample set and reset the parent's seed."""
    self.sample_lib = SampleLib(self.sample_params)
    mean_sampler_config = SampleConfig(
        self.horizon,
        self.d_action,
        tensor_args=self.tensor_args,
        fixed_samples=False,
        seed=2567,
        filter_coeffs=None,
    )
    self.mean_lib = HaltonSampleLib(mean_sampler_config)
    # Draw a new sample set from the recreated sampler.
    self.update_samples()
    super().reset_seed()
def update_samples(self):
    """Regenerate the cached noise sample set and rewind the sample iterator.

    The sample set is stored with shape
    ``(n_iters, n_envs, sampled_particles_per_env, horizon, d_action)``, where
    ``n_iters`` is 1 when ``sample_params.fixed_samples`` is set. With
    ``sample_per_env`` every env gets its own samples; otherwise one set is
    drawn and repeated across envs.
    """
    with profiler.record_function("mppi/update_samples"):
        # A single slice is enough when the same samples are reused every iteration.
        if self.sample_params.fixed_samples:
            n_iters = 1
        else:
            n_iters = self.n_iters
        if self.sample_per_env:
            s_set = (
                self.sample_lib.get_samples(
                    sample_shape=[self.sampled_particles_per_env * self.n_envs * n_iters],
                    base_seed=self.seed,
                )
                .view(
                    n_iters,
                    self.n_envs,
                    self.sampled_particles_per_env,
                    self.horizon,
                    self.d_action,
                )
                .clone()
            )
        else:
            s_set = self.sample_lib.get_samples(
                sample_shape=[n_iters * (self.sampled_particles_per_env)],
                base_seed=self.seed,
            )
            s_set = s_set.view(
                n_iters, 1, self.sampled_particles_per_env, self.horizon, self.d_action
            )
            # Share the same samples across all envs.
            s_set = s_set.repeat(1, self.n_envs, 1, 1, 1).clone()
        # Zero the last particle's noise so that particle equals the current mean.
        s_set[:, :, -1, :, :] = 0.0
        # Reuse the existing buffer when copy_tensor succeeds; otherwise rebind it.
        if not copy_tensor(s_set, self._sample_set):
            log_info("ParallelMPPI: Updating sample set")
            self._sample_set = s_set
        if self._sample_iter is None:
            log_info("ParallelMPPI: Resetting sample iterations")  # , sample_iter.shape)
            self._sample_iter = torch.zeros(
                (1), dtype=torch.long, device=self.tensor_args.device
            )
        else:
            # Reset in place so the existing index tensor object is kept.
            self._sample_iter[:] = 0
        # if not copy_tensor(sample_iter, self._sample_iter):
        #     log_info("ParallelMPPI: Resetting sample iterations") # , sample_iter.shape)
        #     self._sample_iter = sample_iter
        self._sample_iter_n = 0
@torch.no_grad()
def generate_rollouts(self, init_act=None):
    """Sample actions and roll them out with gradient tracking disabled.

    Args:
        init_act: Forwarded unchanged to the parent implementation.

    Returns:
        The trajectories produced by the parent ``generate_rollouts``.
    """
    rollouts = super().generate_rollouts(init_act)
    return rollouts
@torch.jit.script
def jit_calculate_exp_util(beta: float, total_costs):
    """Softmax over the last dimension of ``-total_costs / beta``."""
    neg_inv_beta = -1.0 / beta
    scaled_costs = neg_inv_beta * total_costs
    return torch.softmax(scaled_costs, dim=-1)
@torch.jit.script
def jit_compute_total_cost(gamma_seq, costs):
    """Sum discounted costs over the last dimension, normalized by the first discount."""
    discounted = gamma_seq * costs
    summed = torch.sum(discounted, dim=-1, keepdim=False)
    first_discount = gamma_seq[..., 0]
    return summed / first_discount
@torch.jit.script
def jit_diag_a_cov_update(w, actions, mean_action):
    """Weighted squared deviation of ``actions`` from the mean, reduced to a diagonal update.

    The weighted squares are summed over dim ``-2`` and then averaged over the
    next dim ``-2`` (per the original note: sum across horizon, mean across
    particles); a singleton dim is re-inserted at ``-2``.
    """
    deviation = actions - mean_action.unsqueeze(-3)
    weighted_sq = w * (deviation**2)
    summed = torch.sum(weighted_sq, dim=-2)
    averaged = torch.mean(summed, dim=-2)
    return averaged.unsqueeze(-2)
@torch.jit.script
def jit_blend_cov(cov_action, cov_update, step_size_cov: float, kappa: float):
    """Blend the old covariance with its update and add the constant ``kappa``."""
    keep_fraction = 1.0 - step_size_cov
    return keep_fraction * cov_action + step_size_cov * cov_update + kappa
@torch.jit.script
def jit_blend_mean(mean_action, new_mean, step_size_mean: float):
    """Interpolate from the old mean toward ``new_mean`` by ``step_size_mean``."""
    keep_fraction = 1.0 - step_size_mean
    return keep_fraction * mean_action + step_size_mean * new_mean

View File

@@ -0,0 +1,302 @@
#!/usr/bin/env python
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional
# Third Party
import torch
import torch.autograd.profiler as profiler
# CuRobo
from curobo.opt.opt_base import Optimizer, OptimizerConfig
from curobo.rollout.rollout_base import RolloutBase, Trajectory
from curobo.types.base import TensorDeviceType
from curobo.types.tensor import T_BHDOF_float, T_HDOF_float
from curobo.util.logger import log_error, log_info
class SampleMode(Enum):
    """Selects which action sequence an optimizer's ``_get_action_seq`` returns."""

    # Meanings below follow ParallelMPPI._get_action_seq.
    MEAN = 0  # the current mean action sequence
    BEST = 1  # the stored best trajectory
    SAMPLE = 2  # the mean plus freshly sampled noise
@dataclass
class ParticleOptInfo:
    """Container for optional side information produced by a particle optimizer."""

    # Free-form dictionary of side information; None when nothing was recorded.
    info: Optional[Dict] = None
@dataclass
class ParticleOptConfig(OptimizerConfig):
    """Configuration shared by particle-based (sampling) optimizers."""

    gamma: float
    sample_mode: SampleMode
    seed: int
    calculate_value: bool
    store_rollouts: bool

    def __post_init__(self):
        """Move the action limits to the configured device and validate flags."""
        for limit_name in ("action_highs", "action_lows"):
            on_device = self.tensor_args.to_device(getattr(self, limit_name))
            object.__setattr__(self, limit_name, on_device)
        if self.calculate_value and self.use_cuda_graph:
            log_error("Cannot calculate_value when cuda graph is enabled")
        return super().__post_init__()

    @staticmethod
    def create_data_dict(
        data_dict: Dict,
        rollout_fn: RolloutBase,
        tensor_args: TensorDeviceType = TensorDeviceType(),
        child_dict: Optional[Dict] = None,
    ):
        """Build the constructor keyword dictionary from a raw configuration dictionary.

        Args:
            data_dict: Raw configuration values.
            rollout_fn: Rollout instance forwarded to the parent builder.
            tensor_args: Device/dtype settings forwarded to the parent builder.
            child_dict: Partially processed dictionary; a deep copy of
                ``data_dict`` is used when omitted.

        Returns:
            Dict: ``child_dict`` with ``sample_mode`` resolved to a
            :class:`SampleMode` member and the optional flags defaulted to ``False``.
        """
        if child_dict is None:
            child_dict = deepcopy(data_dict)
        child_dict = OptimizerConfig.create_data_dict(
            data_dict, rollout_fn, tensor_args, child_dict
        )
        child_dict["sample_mode"] = SampleMode[child_dict["sample_mode"]]
        # Both flags are optional in the raw configuration.
        child_dict.setdefault("calculate_value", False)
        child_dict.setdefault("store_rollouts", False)
        return child_dict
class ParticleOptBase(Optimizer, ParticleOptConfig):
"""Base class for sampling based controllers."""
@profiler.record_function("particle_opt/init")
def __init__(
    self,
    config: Optional[ParticleOptConfig] = None,
):
    """Initialize optimizer state shared by all particle-based optimizers.

    Args:
        config: Optional configuration; when given, its fields are copied onto
            this instance before the base optimizer is initialized.
    """
    if config is not None:
        # Populate the dataclass fields explicitly. ``super().__init__`` would
        # resolve to ``Optimizer.__init__`` in this class's MRO, which is called
        # without field arguments right below; the subclasses in this package
        # use the same explicit pattern for their own config classes.
        ParticleOptConfig.__init__(self, **vars(config))
    Optimizer.__init__(self)

    # Discount sequence [1, gamma, gamma^2, ...] with shape (1, horizon).
    discounts = torch.tensor([1.0] + [self.gamma] * (self.horizon - 1))
    gamma_seq = torch.cumprod(discounts, dim=0).reshape(1, self.horizon)
    self.gamma_seq = self.tensor_args.to_device(gamma_seq)

    self.num_steps = 0
    self.trajectories = None
    self.cu_opt_init = False
    self.info = ParticleOptInfo()
    self.update_num_particles_per_env(self.num_particles)
@abstractmethod
def _get_action_seq(self, mode=SampleMode):
    """Return the action sequence to execute, chosen from the current distribution.

    Args:
        mode: A :class:`SampleMode` member selecting how the sequence is
            chosen (for example the mean, or a sample from the distribution).
            NOTE(review): the default value is the enum class itself, not a member.
    """
@abstractmethod
def sample_actions(self, init_act: T_BHDOF_float):
    """Draw action sequences from the current control distribution.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = "sample_actions funtion not implemented"
    raise NotImplementedError(message)
def update_seed(self, init_act):
    """Seed the optimizer with ``init_act``; subclasses must override.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
@abstractmethod
def _update_distribution(self, trajectories: Trajectory):
    """Update the control distribution from rollout trajectories.

    Args:
        trajectories: Rollout result. The implementation in this package
            reads its ``actions`` (sampled action sequences) and ``costs``
            (per-step costs along the rollouts).
    """
def reset(self):
    """Zero the step counter and reset the parent optimizer."""
    self.num_steps = 0
    super().reset()
@abstractmethod
def _calc_val(self, trajectories: Trajectory):
    """Compute the value of a state from rollouts generated by a policy.

    Args:
        trajectories: Rollout result to evaluate.
    """
def check_convergence(self):
    """Report whether the controller has converged.

    Returns:
        bool: ``False`` in this default implementation.
    """
    converged = False
    return converged
def generate_rollouts(self, init_act=None):
    """Sample a batch of action sequences and roll them out.

    Args:
        init_act: Forwarded to ``sample_actions``.

    Returns:
        Whatever ``self.rollout_fn`` returns for the sampled action sequences.
    """
    return self.rollout_fn(self.sample_actions(init_act))
def _optimize(self, init_act: torch.Tensor, shift_steps=0, n_iters=None):
"""
Optimize for best action at current state
Parameters
----------
state : torch.Tensor
state to calculate optimal action from
calc_val : bool
If true, calculate the optimal value estimate
of the state along with action
Returns
-------
action : torch.Tensor
next action to execute
value: float
optimal value estimate (default: 0.)
info: dict
dictionary with side-information
"""
n_iters = n_iters if n_iters is not None else self.n_iters
# create cuda graph:
if self.use_cuda_graph and self.cu_opt_init:
curr_action_seq = self._call_cuda_opt_iters(init_act)
else:
curr_action_seq = self._run_opt_iters(
init_act, n_iters=n_iters, shift_steps=shift_steps
)
if self.use_cuda_graph:
if not self.cu_opt_init:
self._initialize_cuda_graph(init_act, shift_steps=shift_steps)
self.num_steps += 1
if self.calculate_value:
trajectories = self.generate_rollouts(init_act)
value = self._calc_val(trajectories)
self.info["value"] = value
# print(self.act_seq)
return curr_action_seq
def _initialize_cuda_graph(self, init_act: T_HDOF_float, shift_steps=0):
    """Capture the optimization iterations into a CUDA graph for later replay.

    Args:
        init_act: Initial action sequence; a detached clone becomes the static
            input buffer (``self._cu_act_in``) used by the graph.
        shift_steps: Forwarded to ``_run_opt_iters`` during warm-up and capture.
    """
    log_info("ParticleOptBase: Creating Cuda Graph")
    # Static input buffer; _call_cuda_opt_iters copies new inputs into it before replay.
    self._cu_act_in = init_act.detach().clone()

    # create a new stream:
    s = torch.cuda.Stream(device=self.tensor_args.device)
    s.wait_stream(torch.cuda.current_stream(device=self.tensor_args.device))
    with torch.cuda.stream(s):
        # Warm-up runs on the side stream before the graph is captured.
        for _ in range(3):
            self._cu_act_seq = self._run_opt_iters(self._cu_act_in, shift_steps=shift_steps)
    torch.cuda.current_stream(device=self.tensor_args.device).wait_stream(s)

    # Discard the optimizer state changed by the warm-up runs before capturing.
    self.reset()
    self.cu_opt_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.cu_opt_graph, stream=s):
        self._cu_act_seq = self._run_opt_iters(self._cu_act_in, shift_steps=shift_steps)
    self.cu_opt_init = True
def _call_cuda_opt_iters(self, init_act: T_HDOF_float):
    """Replay the captured CUDA graph on ``init_act`` and return a copy of its output."""
    graph_input = init_act.detach()
    self._cu_act_in.copy_(graph_input)
    self.cu_opt_graph.replay()
    # Clone so the caller does not alias the graph's static output buffer.
    graph_output = self._cu_act_seq.detach()
    return graph_output.clone()
def _run_opt_iters(self, init_act: T_HDOF_float, shift_steps=0, n_iters=None):
    """Run the sample / rollout / update loop eagerly.

    Args:
        init_act: Action sequence used to seed the distribution.
        shift_steps: Passed to ``self._shift`` before optimizing.
        n_iters: Number of iterations; defaults to ``self.n_iters``.

    Returns:
        The action sequence selected according to ``self.sample_mode``.
    """
    if n_iters is None:
        n_iters = self.n_iters

    self._shift(shift_steps)
    self.update_seed(init_act)

    # Debug snapshots are only recorded outside of CUDA graph mode.
    if not self.use_cuda_graph and self.store_debug:
        self.debug.append(self._get_action_seq(mode=self.sample_mode).clone())

    for _ in range(n_iters):
        # generate random simulated trajectories
        rollouts = self.generate_rollouts()
        cost_shape = (self.n_envs, self.particles_per_env, self.horizon)
        rollouts.actions = rollouts.actions.view(*cost_shape, self.d_action)
        rollouts.costs = rollouts.costs.view(*cost_shape)

        with profiler.record_function("mpc/mppi/update_distribution"):
            self._update_distribution(rollouts)

        if not self.use_cuda_graph and self.store_debug:
            self.debug.append(self._get_action_seq(mode=self.sample_mode).clone())
            summed_costs = rollouts.costs.sum(dim=-1)
            lowest_cost = summed_costs.min(dim=-1)[0]
            self.debug_cost.append(lowest_cost.unsqueeze(-1).clone())

    return self._get_action_seq(mode=self.sample_mode)
def update_nenvs(self, n_envs):
    """Resize the optimizer for ``n_envs`` environments.

    Recomputes the total particle count, marks the CUDA graph as not
    initialized and forwards the new count to the parent class.
    """
    assert n_envs > 0
    self.cu_opt_init = False
    self.total_num_particles = n_envs * self.num_particles
    super().update_nenvs(n_envs)
def update_num_particles_per_env(self, num_particles_per_env):
    """Split the per-env particle budget into sampled, negated-mean and zero-action particles.

    ``null_act_frac`` of the particles are reserved; half of them (rounded
    down) become zero-action particles and the remainder negated-mean
    particles. A zero buffer ``null_act_seqs`` is allocated only when at least
    one zero-action particle is requested.
    """
    reserved = self.null_act_frac * num_particles_per_env
    # int() truncates toward zero; the original round(int(...)) had the same effect.
    self.null_per_env = int(reserved * 0.5)
    self.neg_per_env = int(reserved) - self.null_per_env
    self.sampled_particles_per_env = (
        num_particles_per_env - self.null_per_env - self.neg_per_env
    )
    self.particles_per_env = num_particles_per_env

    if self.null_per_env > 0:
        zero_shape = (self.null_per_env, self.horizon, self.d_action)
        self.null_act_seqs = torch.zeros(
            *zero_shape,
            device=self.tensor_args.device,
            dtype=self.tensor_args.dtype,
        )

View File

@@ -0,0 +1,306 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from enum import Enum
# Third Party
import numpy as np
import torch
import torch.autograd.profiler as profiler
# CuRobo
from curobo.types.base import TensorDeviceType
class SquashType(Enum):
    """How ``scale_ctrl`` maps raw controls into the action limits."""

    CLAMP = 0  # clip directly to [action_lows, action_highs]
    CLAMP_RESCALE = 1  # clip to [-1, 1], then rescale to the action range
    TANH = 2  # apply tanh, then rescale to the action range
    IDENTITY = 3  # return the controls unchanged
def scale_ctrl(ctrl, action_lows, action_highs, squash_fn: SquashType = SquashType.CLAMP):
    """Map raw controls into the action limits using the selected squash function.

    Args:
        ctrl: Control tensor; a 1-D input is reshaped to ``(1, n, 1)`` first.
        action_lows: Lower action limits.
        action_highs: Upper action limits.
        squash_fn: Squashing strategy. CLAMP clips to the limits, IDENTITY
            returns ``ctrl`` as is, CLAMP_RESCALE / TANH squash to ``[-1, 1]``
            and rescale to the action range.

    Returns:
        The squashed control tensor.
    """
    if ctrl.dim() == 1:
        ctrl = ctrl.unsqueeze(0).unsqueeze(-1)

    half_range = (action_highs - action_lows) / 2.0
    mid_range = (action_highs + action_lows) / 2.0

    if squash_fn == SquashType.CLAMP:
        return torch.max(torch.min(ctrl, action_highs), action_lows)
    if squash_fn == SquashType.IDENTITY:
        return ctrl

    if squash_fn == SquashType.CLAMP_RESCALE:
        ctrl = torch.clamp(ctrl, -1.0, 1.0)
    elif squash_fn == SquashType.TANH:
        ctrl = torch.tanh(ctrl)
    # Rescale from [-1, 1] to [action_lows, action_highs].
    return mid_range.unsqueeze(0) + ctrl * half_range.unsqueeze(0)
#######################
## STOMP Covariance ##
#######################
@profiler.record_function("particle_opt_utils/get_stomp_cov")
def get_stomp_cov(
    horizon: int,
    d_action: int,
    tensor_args=TensorDeviceType(),
    cov_mode="acc",
    RETURN_M=False,
):
    """Compute the STOMP-style covariance and its Cholesky factor on the target device.

    Coefficients from here: https://en.wikipedia.org/wiki/Finite_difference_coefficient
    More info here: https://github.com/ros-industrial/stomp_ros/blob/7fe40fbe6ad446459d8d4889916c64e276dbf882/stomp_core/src/utils.cpp#L36

    Args:
        horizon: Number of time steps.
        d_action: Action dimension.
        tensor_args: Device/dtype settings used to place the results.
        cov_mode: Finite-difference mode passed to ``get_stomp_cov_jit``.
        RETURN_M: When true, also return the scaled ``M`` matrix.

    Returns:
        ``(cov, scale_tril)`` or, with ``RETURN_M``, ``(cov, scale_tril, scaled_M)``.
    """
    raw_cov, raw_tril, raw_scaled_m = get_stomp_cov_jit(horizon, d_action, cov_mode)
    result = (tensor_args.to_device(raw_cov), tensor_args.to_device(raw_tril))
    if RETURN_M:
        result = result + (tensor_args.to_device(raw_scaled_m),)
    return result
@torch.jit.script
def get_stomp_cov_jit(
    horizon: int,
    d_action: int,
    cov_mode: str = "acc",
):
    """Build the STOMP covariance from a banded finite-difference matrix.

    A ``(d_action * horizon)``-square matrix ``A`` is filled with the stencil
    in ``fd_array``; ``R = A^T A`` is inverted to ``M``, which is normalized to
    give the covariance and a row-scaled variant.

    Returns:
        Tuple ``(cov, scale_tril, scaled_M)`` of float64 tensors, where
        ``scale_tril`` is the Cholesky factor of ``cov``.
    """
    # NOTE(review): despite the name, 1, -2, 1 are the central second-difference
    # coefficients; both modes below use this same stencil — confirm intent.
    vel_fd_array = [0.0, 0.0, 1.0, -2.0, 1.0, 0.0, 0.0]

    fd_array = vel_fd_array
    A = torch.zeros(
        (d_action * horizon, d_action * horizon),
        dtype=torch.float64,
    )

    if cov_mode == "vel":
        for k in range(d_action):
            for i in range(0, horizon):
                for j in range(-3, 4):
                    # print(j)
                    index = i + j
                    # Stencil entries falling outside the horizon are skipped.
                    if index < 0:
                        index = 0
                        continue
                    if index >= horizon:
                        index = horizon - 1
                        continue
                    A[k * horizon + i, k * horizon + index] = fd_array[j + 3]
    elif cov_mode == "acc":
        for k in range(d_action):
            for i in range(0, horizon):
                for j in range(-3, 4):
                    index = i + j
                    # Stencil entries falling outside the horizon are skipped.
                    if index < 0:
                        index = 0
                        continue
                    if index >= horizon:
                        index = horizon - 1
                        continue

                    # Entries in the second half of the horizon are written to a
                    # different (possibly negative, i.e. wrapped) column index.
                    if index >= horizon / 2:
                        A[k * horizon + i, k * horizon - index - horizon // 2 - 1] = fd_array[j + 3]
                    else:
                        A[k * horizon + i, k * horizon + index] = fd_array[j + 3]

    R = torch.matmul(A.transpose(-2, -1), A)
    M = torch.inverse(R)
    # Divide by the per-row maximum magnitude (broadcast along dim 0) and by the horizon.
    scaled_M = (1 / horizon) * M / (torch.max(torch.abs(M), dim=1)[0].unsqueeze(0))
    # Normalize so the largest-magnitude entry of the covariance is 1.
    cov = M / torch.max(torch.abs(M))

    # also compute the cholesky decomposition:
    # scale_tril = torch.zeros((d_action * horizon, d_action * horizon), **tensor_args)
    scale_tril = torch.linalg.cholesky(cov)
    """
k = 0
act_cov_matrix = cov[k * horizon:k * horizon + horizon, k * horizon:k * horizon + horizon]
print(act_cov_matrix.shape)
print(torch.det(act_cov_matrix))
local_cholesky = matrix_cholesky(act_cov_matrix)
for k in range(d_action):
scale_tril[k * horizon:k * horizon + horizon,k * horizon:k * horizon + horizon] = local_cholesky
"""
    return cov, scale_tril, scaled_M
########################
## Gaussian Utilities ##
########################
def gaussian_logprob(mean, cov, x, cov_type="full"):
    """Calculate the Gaussian log-probability of each sample in a batch.

    Args:
        mean (np.ndarray): [N x num_samples] batch of means.
        cov (np.ndarray): [N x N] covariance matrix.
        x (np.ndarray): [N x num_samples] batch of sample values.
        cov_type (str): ``"diagonal"`` uses only the diagonal of ``cov``; any
            other value treats ``cov`` as a full matrix.

    Returns:
        np.ndarray: [num_samples] log-probability of each sample.
    """
    N = cov.shape[0]
    if cov_type == "diagonal":
        cov_diag = cov.diagonal()
        cov_inv = np.diag(1.0 / cov_diag)
        cov_logdet = np.sum(np.log(cov_diag))
    else:
        # slogdet avoids the under/overflow of det() for larger matrices.
        sign, logabsdet = np.linalg.slogdet(cov)
        # Keep log(det) semantics: -inf for a zero determinant, NaN for a negative one.
        cov_logdet = logabsdet if sign >= 0 else np.nan
        cov_inv = np.linalg.inv(cov)
    diff = (x - mean).T
    mahalanobis_dist = -0.5 * np.sum((diff @ cov_inv) * diff, axis=1)
    const1 = -0.5 * N * np.log(2.0 * np.pi)
    const2 = -0.5 * cov_logdet
    log_prob = mahalanobis_dist + const1 + const2
    return log_prob
def gaussian_logprobgrad(mean, cov, x, cov_type="full"):
    """
    Compute ``(x - mean)^T @ inv(cov)`` for a batch of samples

    Parameters
    ----------
    mean (np.ndarray): batch of means, broadcastable against ``x``
    cov (np.ndarray): square covariance matrix
    x (np.ndarray): batch of sample values; presumably [N x num_samples] like
        ``gaussian_logprob`` -- TODO confirm
    cov_type (str): "diagonal" inverts only the diagonal of ``cov``; any other
        value inverts the full matrix

    Returns
    --------
    grad (np.ndarray): transposed residuals multiplied by the inverse covariance

    NOTE(review): the sign matches the gradient of the Gaussian log-density
    with respect to the mean (the gradient with respect to x is its negative)
    -- confirm which one callers expect.
    """
    use_diagonal = cov_type == "diagonal"
    precision = np.diag(1.0 / cov.diagonal()) if use_diagonal else np.linalg.inv(cov)
    residual = (x - mean).T
    return residual @ precision
def gaussian_entropy(cov=None, L=None):  # , cov_type="full"):
    """
    Entropy of multivariate gaussian given either covariance
    or cholesky decomposition of covariance

    The entropy is ``0.5 * logdet(cov) + 0.5 * N * (1 + log(2 * pi))``. When
    ``cov`` is given, the log-determinant comes from ``torch.linalg.slogdet``
    rather than ``log(det(cov))`` so that it does not collapse to ``-inf`` when
    the determinant underflows for large N.

    Parameters
    ----------
    cov (torch.Tensor): [N x N] covariance matrix; takes precedence over ``L``
        when both are supplied
    L (torch.Tensor): [N x N] cholesky factor of the covariance; only its
        diagonal is read

    Returns
    --------
    ent (torch.Tensor): scalar entropy on the device of the input tensor

    Raises
    ------
    ValueError: if neither ``cov`` nor ``L`` is provided
    """
    if cov is not None:
        inp_device = cov.device
        sign, logabsdet = torch.linalg.slogdet(cov)
        # A negative determinant has no real logarithm; keep reporting nan for
        # it, exactly as log(det(cov)) would.
        cov_logdet = torch.where(
            sign < 0, torch.full_like(logabsdet, float("nan")), logabsdet
        )
        N = cov.shape[0]
    elif L is not None:
        inp_device = L.device
        # det(cov) = prod(diag(L)) ** 2 for cov = L @ L^T.
        cov_logdet = 2.0 * torch.sum(torch.log(torch.diagonal(L)))
        N = L.shape[0]
    else:
        raise ValueError("gaussian_entropy requires either cov or L")
    term1 = 0.5 * cov_logdet
    # pre-calculated 1.0 + log(2.0 * pi) = 2.837877066
    term2 = 0.5 * N * 2.837877066
    ent = term1 + term2
    return ent.to(inp_device)
def gaussian_kl(mean0, cov0, mean1, cov1, cov_type="full"):
    """
    KL-divergence between Gaussians given mean and covariance
    KL(p||q) = E_{p}[log(p) - log(q)]

    Here p has parameters (mean0, cov0) and q has parameters (mean1, cov1).
    For full covariances the log-determinants come from ``np.linalg.slogdet``
    rather than ``log(det(cov))``; otherwise both determinants can underflow to
    zero for large N and their difference becomes ``nan``.

    Parameters
    ----------
    mean0 (np.ndarray): [N x num_samples] batch of means of p
    cov0 (np.ndarray): [N x N] covariance matrix of p
    mean1 (np.ndarray): [N x num_samples] batch of means of q
    cov1 (np.ndarray): [N x N] covariance matrix of q
    cov_type (str): "diagonal" uses only the diagonals for the inverse and the
        log-determinants; any other value treats both as full matrices

    Returns
    --------
    kl (np.ndarray): [num_samples] divergence for each column of the means
    """

    def _logdet(cov):
        # Log-determinant that survives determinant under/overflow; a negative
        # determinant still maps to nan, as log(det(cov)) would.
        sign, logabsdet = np.linalg.slogdet(cov)
        return np.nan if sign < 0 else logabsdet

    N = cov0.shape[0]
    if cov_type == "diagonal":
        cov1_diag = cov1.diagonal()
        cov1_inv = np.diag(1.0 / cov1_diag)
        cov0_logdet = np.sum(np.log(cov0.diagonal()))
        cov1_logdet = np.sum(np.log(cov1_diag))
    else:
        cov1_inv = np.linalg.inv(cov1)
        cov0_logdet = _logdet(cov0)
        cov1_logdet = _logdet(cov1)
    term1 = 0.5 * np.trace(cov1_inv @ cov0)
    diff = (mean1 - mean0).T  # one row per sample
    mahalanobis_dist = 0.5 * np.sum((diff @ cov1_inv) * diff, axis=1)
    term3 = 0.5 * (-1.0 * N + cov1_logdet - cov0_logdet)
    return term1 + mahalanobis_dist + term3
# @torch.jit.script
def cost_to_go(cost_seq, gamma_seq, only_first=False):
    # type: (Tensor, Tensor, bool) -> Tensor
    """
    Calculate (discounted) cost to go for given cost sequence

    The time/horizon axis is the last dimension of ``cost_seq``; any number of
    leading dimensions is supported. ``gamma_seq`` must broadcast against
    ``cost_seq`` and must not contain zeros, because the result is divided by it.

    Parameters
    ----------
    cost_seq (Tensor): [... x horizon] per-step costs
    gamma_seq (Tensor): [... x horizon] per-step discount factors
    only_first (bool): if True, return only the cost to go from the first step,
        with the last dimension kept at size 1

    Returns
    --------
    cost_seq (Tensor): discounted cost to go; same shape as the broadcast inputs,
        or [... x 1] when ``only_first`` is True
    """
    discounted = gamma_seq * cost_seq  # discounted cost sequence
    if only_first:
        return torch.sum(discounted, dim=-1, keepdim=True) / gamma_seq[..., 0]
    # Reverse cumulative sum along the horizon axis. Flipping the same axis the
    # cumsum runs over (dim -1) keeps this correct for inputs of any rank.
    # The result is the cost to go, but scaled by [1, gamma, gamma**2, ...].
    reversed_cumsum = torch.flip(
        torch.cumsum(torch.flip(discounted, dims=[-1]), dim=-1), dims=[-1]
    )
    # un-scale it to get true discounted cost to go
    return reversed_cumsum / gamma_seq
def cost_to_go_np(cost_seq, gamma_seq):
    """
    Calculate (discounted) cost to go for given cost sequence

    NumPy counterpart of ``cost_to_go``. The time/horizon axis is the last
    dimension of ``cost_seq``; any number of leading dimensions is supported.
    ``gamma_seq`` must broadcast against ``cost_seq`` and must not contain
    zeros, because the result is divided by it.

    Parameters
    ----------
    cost_seq (np.ndarray): [... x horizon] per-step costs
    gamma_seq (np.ndarray): [... x horizon] per-step discount factors

    Returns
    --------
    cost_seq (np.ndarray): discounted cost to go for every step
    """
    discounted = gamma_seq * cost_seq  # discounted cost sequence
    # Reverse cumulative sum along the last axis: cost to go, but scaled by
    # [1, gamma, gamma**2, ...].
    reversed_cumsum = np.cumsum(discounted[..., ::-1], axis=-1)[..., ::-1]
    # Un-scale to get the true discounted cost to go. Out-of-place division so
    # that integer inputs are promoted to float instead of raising.
    return reversed_cumsum / gamma_seq
############
##Cholesky##
############
def matrix_cholesky(A):
    """
    Cholesky factorisation of a single square matrix, computed entry by entry

    Fills a zero-initialised tensor row by row with the lower-triangular factor
    ``L`` of ``A = L @ L^T``. Only the lower triangle (including the diagonal)
    of ``A`` is read, and no symmetry or positive-definiteness check is made.

    Parameters
    ----------
    A (torch.Tensor): [N x N] matrix to factorise

    Returns
    --------
    L (torch.Tensor): [N x N] lower-triangular factor, same dtype/device as A
    """
    lower = torch.zeros_like(A)
    size = A.shape[-1]
    for row in range(size):
        for col in range(row + 1):
            # Dot product of the already-computed parts of rows `row` and `col`.
            partial = 0.0
            for idx in range(col):
                partial = partial + lower[row, idx] * lower[col, idx]
            if row == col:
                lower[row, col] = torch.sqrt(A[row, row] - partial)
            else:
                lower[row, col] = 1.0 / lower[col, col] * (A[row, col] - partial)
    return lower
# Batched Cholesky decomp
def batch_cholesky(A):
    """
    Cholesky factorisation applied independently to a batch of square matrices

    Same entry-by-entry recurrence as ``matrix_cholesky``, but every access
    indexes the last two dimensions so all leading dimensions are treated as
    batch. Only the lower triangle (including the diagonal) of each matrix is
    read, and no symmetry or positive-definiteness check is made.

    Parameters
    ----------
    A (torch.Tensor): [... x N x N] matrices to factorise

    Returns
    --------
    L (torch.Tensor): [... x N x N] lower-triangular factors, same dtype/device as A
    """
    lower = torch.zeros_like(A)
    size = A.shape[-1]
    for row in range(size):
        for col in range(row + 1):
            # Batched dot product of the already-computed parts of the two rows.
            partial = 0.0
            for idx in range(col):
                partial = partial + lower[..., row, idx] * lower[..., col, idx]
            if row == col:
                entry = torch.sqrt(A[..., row, row] - partial)
            else:
                entry = 1.0 / lower[..., col, col] * (A[..., row, col] - partial)
            lower[..., row, col] = entry
    return lower