release repository

src/curobo/opt/newton/lbfgs.py (new file, 215 lines)
@@ -0,0 +1,215 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#

# Standard Library
from dataclasses import dataclass
from typing import Optional

# Third Party
import torch
import torch.autograd.profiler as profiler

# CuRobo
from curobo.curobolib.opt import LBFGScu
from curobo.opt.newton.newton_base import NewtonOptBase, NewtonOptConfig
from curobo.util.logger import log_warn


# kernel for l-bfgs:
@torch.jit.script
def compute_step_direction(
    alpha_buffer, rho_buffer, y_buffer, s_buffer, grad_q, m: int, epsilon, stable_mode: bool = True
):
    # m = 15 (int)
    # y_buffer, s_buffer: m x b x 175
    # q, grad_q: b x 175
    # rho_buffer: m x b x 1
    # alpha_buffer: m x b x 1 # this can be dynamically created
    gq = grad_q.detach().clone()
    rho_s = rho_buffer * s_buffer.transpose(-1, -2)  # batched m_scalar-m_vector product
    for i in range(m - 1, -1, -1):
        alpha_buffer[i] = rho_s[i] @ gq  # batched vector-vector dot product
        gq = gq - (alpha_buffer[i] * y_buffer[i])  # batched scalar-vector product

    var1 = (s_buffer[-1].transpose(-1, -2) @ y_buffer[-1]) / (
        y_buffer[-1].transpose(-1, -2) @ y_buffer[-1]
    )
    if stable_mode:
        var1 = torch.nan_to_num(var1, epsilon, epsilon, epsilon)
    gamma = torch.nn.functional.relu(var1)  # + epsilon
    r = gamma * gq.clone()
    rho_y = rho_buffer * y_buffer.transpose(-1, -2)  # batched m_scalar-m_vector product
    for i in range(m):
        # batched dot product, scalar-vector product
        r = r + (alpha_buffer[i] - (rho_y[i] @ r)) * s_buffer[i]
    return -1.0 * r


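# Added note: compute_step_direction() above is a batched form of the classic
# L-BFGS two-loop recursion (see Nocedal & Wright, Numerical Optimization).
# The sketch below is an unbatched reference for readability only; the function
# name and the flat 1-D tensor layout are illustrative and not part of the
# cuRobo API.
def _reference_two_loop_direction(grad, s_list, y_list):
    # s_list[i], y_list[i]: 1-D tensors holding x_{i+1} - x_i and g_{i+1} - g_i,
    # ordered oldest to newest
    q = grad.clone()
    stack = []
    for s, y in zip(reversed(s_list), reversed(y_list)):
        rho = 1.0 / torch.dot(y, s)
        alpha = rho * torch.dot(s, q)
        stack.append((alpha, rho, s, y))
        q = q - alpha * y
    # scale by s^T y / y^T y of the newest pair (initial inverse-Hessian estimate)
    s_new, y_new = s_list[-1], y_list[-1]
    r = (torch.dot(s_new, y_new) / torch.dot(y_new, y_new)) * q
    for alpha, rho, s, y in reversed(stack):
        beta = rho * torch.dot(y, r)
        r = r + (alpha - beta) * s
    return -r  # descent direction

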
@dataclass
class LBFGSOptConfig(NewtonOptConfig):
    history: int = 10  # number of (s, y) pairs kept for the two-loop recursion
    epsilon: float = 0.01  # fallback value when the curvature-based scaling is degenerate
    use_cuda_kernel: bool = True  # use the fused LBFGScu kernel when problem size permits
    stable_mode: bool = True  # replace NaN/Inf intermediates with safe values

    def __post_init__(self):
        return super().__post_init__()


class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
    @profiler.record_function("lbfgs_opt/init")
    def __init__(self, config: Optional[LBFGSOptConfig] = None):
        if config is not None:
            LBFGSOptConfig.__init__(self, **vars(config))
        NewtonOptBase.__init__(self)
        if self.d_opt >= 1024 or self.history >= 512:
            log_warn("LBFGS: Not using LBFGS CUDA kernel as d_opt >= 1024 or history >= 512")
            self.use_cuda_kernel = False
        if self.history > self.d_opt:
            log_warn("LBFGS: history > d_opt, reducing history to d_opt")
            self.history = self.d_opt

    @profiler.record_function("lbfgs/reset")
    def reset(self):
        self.x_0[:] = 0.0
        self.grad_0[:] = 0.0
        self.s_buffer[:] = 0.0
        self.y_buffer[:] = 0.0
        self.rho_buffer[:] = 0.0
        self.alpha_buffer[:] = 0.0
        self.step_q_buffer[:] = 0.0
        return super().reset()

    def update_nenvs(self, n_envs):
        self.init_hessian(b=n_envs)
        return super().update_nenvs(n_envs)

    def init_hessian(self, b=1):
        self.x_0 = torch.zeros(
            (b, self.d_opt, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self.grad_0 = torch.zeros(
            (b, self.d_opt, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self.y_buffer = torch.zeros(
            (self.history, b, self.d_opt, 1),
            device=self.tensor_args.device,
            dtype=self.tensor_args.dtype,
        )  # + 1.0
        self.s_buffer = torch.zeros(
            (self.history, b, self.d_opt, 1),
            device=self.tensor_args.device,
            dtype=self.tensor_args.dtype,
        )  # + 1.0
        self.rho_buffer = torch.zeros(
            (self.history, b, 1, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )  # + 1.0
        self.step_q_buffer = torch.zeros(
            (b, self.d_opt), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self.alpha_buffer = torch.zeros(
            (self.history, b, 1, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )  # + 1.0

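    # Added summary of the buffers allocated by init_hessian() above
    # (m = history, b = batch size, d = d_opt):
    #   x_0, grad_0:               (b, d, 1)    previous iterate and its gradient
    #   s_buffer, y_buffer:        (m, b, d, 1) iterate / gradient differences
    #   rho_buffer, alpha_buffer:  (m, b, 1, 1) curvature scalars and first-loop coefficients
    #   step_q_buffer:             (b, d)       scratch buffer passed to the fused CUDA kernel
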
    @torch.no_grad()
    def _get_step_direction(self, cost, q, grad_q):
        if self.use_cuda_kernel:
            # fused CUDA kernel path
            with profiler.record_function("lbfgs/fused"):
                dq = LBFGScu._call_cuda(
                    self.step_q_buffer,
                    self.rho_buffer,
                    self.y_buffer,
                    self.s_buffer,
                    q,
                    grad_q,
                    self.x_0,
                    self.grad_0,
                    self.epsilon,
                    self.stable_mode,
                )

        else:
            # PyTorch fallback: update the history buffers, then run the
            # TorchScript two-loop recursion
            grad_q = grad_q.transpose(-1, -2)
            q = q.unsqueeze(-1)
            self._update_buffers(q, grad_q)
            dq = compute_step_direction(
                self.alpha_buffer,
                self.rho_buffer,
                self.y_buffer,
                self.s_buffer,
                grad_q,
                self.history,
                self.epsilon,
                self.stable_mode,
            )

        dq = dq.view(-1, self.d_opt)
        return dq

    def _update_q(self, grad_q):
        # first loop of the two-loop recursion applied to the gradient
        q = grad_q.detach().clone()
        rho_s = self.rho_buffer * self.s_buffer.transpose(-1, -2)
        for i in range(self.history - 1, -1, -1):
            self.alpha_buffer[i] = rho_s[i] @ q
            q = q - (self.alpha_buffer[i] * self.y_buffer[i])
        return q

    def _update_r(self, r_init):
        # second loop of the two-loop recursion, returning the (negated) step
        r = r_init.clone()
        rho_y = self.rho_buffer * self.y_buffer.transpose(-1, -2)
        for i in range(self.history):
            r = r + self.s_buffer[i] * (self.alpha_buffer[i] - rho_y[i] @ r)
        return -1.0 * r

    def _update_buffers(self, q, grad_q):
        y = grad_q - self.grad_0
        s = q - self.x_0
        rho = 1 / (y.transpose(-1, -2) @ s)
        if self.stable_mode:
            rho = torch.nan_to_num(rho, 0, 0, 0)
        self.s_buffer[0] = s
        self.s_buffer[:] = torch.roll(self.s_buffer, -1, dims=0)
        self.y_buffer[0] = y
        self.y_buffer[:] = torch.roll(self.y_buffer, -1, dims=0)  # .copy_(y_buff)

        self.rho_buffer[0] = rho

        self.rho_buffer[:] = torch.roll(self.rho_buffer, -1, dims=0)

        self.x_0.copy_(q)
        self.grad_0.copy_(grad_q)

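    # Added note: _update_buffers() keeps the history as a ring ordered oldest to
    # newest along dim 0, so the newest (s, y) pair sits at index -1, matching the
    # s_buffer[-1] / y_buffer[-1] reads in compute_step_direction(). Toy
    # illustration with history = 3 (comment only):
    #   buf = torch.zeros(3)
    #   for val in (1.0, 2.0, 3.0, 4.0):
    #       buf[0] = val
    #       buf = torch.roll(buf, -1, dims=0)
    #   # buf is now tensor([2., 3., 4.]); the oldest entry (1.) was overwritten
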
    def _shift(self, shift_steps=0):
        """Shift the optimizer buffers by shift_steps * d_action.

        Args:
            shift_steps (int, optional): Number of timesteps to shift the stored
                buffers by. Defaults to 0.
        """
        if shift_steps == 0:
            return
        self.reset()
        shift_d = shift_steps * self.d_action
        self.x_0 = self._shift_buffer(self.x_0, shift_d, shift_steps)
        self.grad_0 = self._shift_buffer(self.grad_0, shift_d, shift_steps)
        self.y_buffer = self._shift_buffer(self.y_buffer, shift_d, shift_steps)
        self.s_buffer = self._shift_buffer(self.s_buffer, shift_d, shift_steps)
        super()._shift(shift_steps=shift_steps)

    def _shift_buffer(self, buffer, shift_d, shift_steps: int = 1):
        # roll the optimization-variable dimension back by shift_d and refill the
        # vacated slots from the last retained action block
        buffer = buffer.roll(-shift_d, -2)
        end_value = -(shift_steps - 1) * self.d_action
        if end_value == 0:
            end_value = buffer.shape[-2]
        buffer[..., -shift_d:end_value, :] = buffer[
            ..., -shift_d - self.d_action : -shift_d, :
        ].clone()

        return buffer