#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
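"""L-BFGS optimizer: builds step directions from a rolling history of curvature
pairs via the two-loop recursion, either in pure PyTorch or through the fused
LBFGScu CUDA kernel."""
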
# Standard Library
from dataclasses import dataclass
from typing import Optional

# Third Party
import torch
import torch.autograd.profiler as profiler

# CuRobo
from curobo.curobolib.opt import LBFGScu
from curobo.opt.newton.newton_base import NewtonOptBase, NewtonOptConfig
from curobo.util.logger import log_warn


# kernel for l-bfgs:
# @torch.jit.script
def compute_step_direction(
    alpha_buffer,
    rho_buffer,
    y_buffer,
    s_buffer,
    grad_q,
    m: int,
    epsilon: float,
    stable_mode: bool = True,
):
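    """Compute the L-BFGS step direction using the standard two-loop recursion.

    The backward loop folds the stored curvature pairs (s_i, y_i) into the gradient,
    the initial Hessian is approximated by the scaling ``gamma``, and the forward
    loop unrolls the pairs again; the negated result is the descent direction.
    """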
    # m = 15 (int)
    # y_buffer, s_buffer: m x b x 175
    # q, grad_q: b x 175
    # rho_buffer: m x b x 1
    # alpha_buffer: m x b x 1 # this can be dynamically created
    gq = grad_q.detach().clone()
    rho_s = rho_buffer * s_buffer.transpose(-1, -2)  # batched m_scalar-m_vector product
    for i in range(m - 1, -1, -1):
        alpha_buffer[i] = rho_s[i] @ gq  # batched vector-vector dot product
        gq = gq - (alpha_buffer[i] * y_buffer[i])  # batched scalar-vector product

    var1 = (s_buffer[-1].transpose(-1, -2) @ y_buffer[-1]) / (
        y_buffer[-1].transpose(-1, -2) @ y_buffer[-1]
    )
    if stable_mode:
        var1 = torch.nan_to_num(var1, epsilon, epsilon, epsilon)
    gamma = torch.nn.functional.relu(var1)  # + epsilon
    r = gamma * gq.clone()
    rho_y = rho_buffer * y_buffer.transpose(-1, -2)  # batched m_scalar-m_vector product
    for i in range(m):
        # batched dot product, scalar-vector product
        r = r + (alpha_buffer[i] - (rho_y[i] @ r)) * s_buffer[i]
    return -1.0 * r


@dataclass
class LBFGSOptConfig(NewtonOptConfig):
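    """Configuration for the L-BFGS optimizer.

    history: number of curvature pairs kept in the rolling buffers.
    epsilon: fallback value substituted when the Hessian scaling is NaN/Inf in stable mode.
    use_cuda_kernel: use the fused LBFGScu kernel when d_opt and history permit.
    stable_mode: replace NaN/Inf intermediates with safe values.
    """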
    history: int = 10
    epsilon: float = 0.01
    use_cuda_kernel: bool = True
    stable_mode: bool = True

    def __post_init__(self):
        return super().__post_init__()


class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
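    """Newton-style optimizer that uses L-BFGS curvature estimates for the step direction."""
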
    @profiler.record_function("lbfgs_opt/init")
    def __init__(self, config: Optional[LBFGSOptConfig] = None):
        if config is not None:
            LBFGSOptConfig.__init__(self, **vars(config))
        NewtonOptBase.__init__(self)
        if self.d_opt >= 1024 or self.history > 15:
            log_warn("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>15")
            self.use_cuda_kernel = False
        if self.history >= self.d_opt:
            log_warn("LBFGS: history >= d_opt, reducing history to d_opt-1")
            self.history = self.d_opt - 1

    @profiler.record_function("lbfgs/reset")
    def reset(self):
        self.x_0[:] = 0.0
        self.grad_0[:] = 0.0
        self.s_buffer[:] = 0.0
        self.y_buffer[:] = 0.0
        self.rho_buffer[:] = 0.0
        self.alpha_buffer[:] = 0.0
        self.step_q_buffer[:] = 0.0
        return super().reset()

    def update_nenvs(self, n_envs):
        self.init_hessian(b=n_envs)
        return super().update_nenvs(n_envs)

    def init_hessian(self, b=1):
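        """Allocate the history buffers for a batch of ``b`` problems.

        x_0/grad_0 hold the previous iterate and gradient, y_buffer/s_buffer hold the
        last ``history`` curvature pairs, rho_buffer/alpha_buffer hold the associated
        scalars, and step_q_buffer is the workspace passed to the fused CUDA kernel.
        """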
        self.x_0 = torch.zeros(
            (b, self.d_opt, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self.grad_0 = torch.zeros(
            (b, self.d_opt, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self.y_buffer = torch.zeros(
            (self.history, b, self.d_opt, 1),
            device=self.tensor_args.device,
            dtype=self.tensor_args.dtype,
        )  # + 1.0
        self.s_buffer = torch.zeros(
            (self.history, b, self.d_opt, 1),
            device=self.tensor_args.device,
            dtype=self.tensor_args.dtype,
        )  # + 1.0
        self.rho_buffer = torch.zeros(
            (self.history, b, 1, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )  # + 1.0
        self.step_q_buffer = torch.zeros(
            (b, self.d_opt), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self.alpha_buffer = torch.zeros(
            (self.history, b, 1, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )  # + 1.0

    @torch.no_grad()
    def _get_step_direction(self, cost, q, grad_q):
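        """Return the L-BFGS step direction for the current gradient.

        Uses the fused LBFGScu CUDA kernel when enabled; otherwise updates the
        history buffers in PyTorch and runs the two-loop recursion.
        """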
        if self.use_cuda_kernel:
            with profiler.record_function("lbfgs/fused"):
                dq = LBFGScu._call_cuda(
                    self.step_q_buffer,
                    self.rho_buffer,
                    self.y_buffer,
                    self.s_buffer,
                    q,
                    grad_q,
                    self.x_0,
                    self.grad_0,
                    self.epsilon,
                    self.stable_mode,
                )
        else:
            grad_q = grad_q.transpose(-1, -2)
            q = q.unsqueeze(-1)
            self._update_buffers(q, grad_q)
            dq = compute_step_direction(
                self.alpha_buffer,
                self.rho_buffer,
                self.y_buffer,
                self.s_buffer,
                grad_q,
                self.history,
                self.epsilon,
                self.stable_mode,
            )

        dq = dq.view(-1, self.d_opt)
        return dq

    def _update_q(self, grad_q):
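        """Backward pass of the two-loop recursion: fold the stored pairs into the gradient."""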
        q = grad_q.detach().clone()
        rho_s = self.rho_buffer * self.s_buffer.transpose(-1, -2)
        for i in range(self.history - 1, -1, -1):
            self.alpha_buffer[i] = rho_s[i] @ q
            q = q - (self.alpha_buffer[i] * self.y_buffer[i])
        return q

    def _update_r(self, r_init):
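        """Forward pass of the two-loop recursion; returns the negated step direction."""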
        r = r_init.clone()
        rho_y = self.rho_buffer * self.y_buffer.transpose(-1, -2)
        for i in range(self.history):
            r = r + self.s_buffer[i] * (self.alpha_buffer[i] - rho_y[i] @ r)
        return -1.0 * r

    def _update_buffers(self, q, grad_q):
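        """Push the newest curvature pair (s, y) and its scalar rho into the history.

        The newest entry is written to slot 0 and the buffers are rolled by -1 so the
        most recent pair always sits at index -1, matching the order expected by the
        two-loop recursion.
        """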
        y = grad_q - self.grad_0
        s = q - self.x_0
        rho = 1 / (y.transpose(-1, -2) @ s)
        if self.stable_mode:
            rho = torch.nan_to_num(rho, 0, 0, 0)
        self.s_buffer[0] = s
        self.s_buffer[:] = torch.roll(self.s_buffer, -1, dims=0)
        self.y_buffer[0] = y
        self.y_buffer[:] = torch.roll(self.y_buffer, -1, dims=0)  # .copy_(y_buff)
        self.rho_buffer[0] = rho
        self.rho_buffer[:] = torch.roll(self.rho_buffer, -1, dims=0)

        self.x_0.copy_(q)
        self.grad_0.copy_(grad_q)

    def _shift(self, shift_steps=0):
        """Shift the L-BFGS buffers forward by shift_steps * d_action entries.

        Args:
            shift_steps (int, optional): Number of steps to shift the buffers by.
                Defaults to 0.
        """
        if shift_steps == 0:
            return
        self.reset()
        shift_d = shift_steps * self.d_action
        self.x_0 = self._shift_buffer(self.x_0, shift_d, shift_steps)
        self.grad_0 = self._shift_buffer(self.grad_0, shift_d, shift_steps)
        self.y_buffer = self._shift_buffer(self.y_buffer, shift_d, shift_steps)
        self.s_buffer = self._shift_buffer(self.s_buffer, shift_d, shift_steps)
        super()._shift(shift_steps=shift_steps)

    def _shift_buffer(self, buffer, shift_d, shift_steps: int = 1):
        buffer = buffer.roll(-shift_d, -2)
        end_value = -(shift_steps - 1) * self.d_action
        if end_value == 0:
            end_value = buffer.shape[-2]
        buffer[..., -shift_d:end_value, :] = buffer[
            ..., -shift_d - self.d_action : -shift_d, :
        ].clone()

        return buffer
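

if __name__ == "__main__":
    # Minimal sketch (not part of the original module): exercise the pure-PyTorch
    # two-loop recursion on random history buffers, using the shapes documented in
    # compute_step_direction (m history entries, batch b, d_opt variables).
    m, b, d_opt = 4, 2, 8
    y_buffer = torch.randn(m, b, d_opt, 1)
    s_buffer = torch.randn(m, b, d_opt, 1)
    rho_buffer = 1.0 / (y_buffer.transpose(-1, -2) @ s_buffer)  # 1 / (y^T s) per pair
    alpha_buffer = torch.zeros(m, b, 1, 1)
    grad_q = torch.randn(b, d_opt, 1)
    dq = compute_step_direction(alpha_buffer, rho_buffer, y_buffer, s_buffer, grad_q, m, 0.01)
    print(dq.shape)  # torch.Size([2, 8, 1])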