gen_data_curobo/src/curobo/opt/newton/lbfgs.py
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
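
"""L-BFGS step-direction computation for CuRobo's Newton-based optimizers.

Provides a batched PyTorch implementation of the L-BFGS two-loop recursion
(:func:`compute_step_direction`) and an :class:`LBFGSOpt` optimizer that can
dispatch to the fused ``LBFGScu`` CUDA kernel instead.
"""
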
# Standard Library
from dataclasses import dataclass
from typing import Optional

# Third Party
import torch
import torch.autograd.profiler as profiler

# CuRobo
from curobo.curobolib.opt import LBFGScu
from curobo.opt.newton.newton_base import NewtonOptBase, NewtonOptConfig
from curobo.util.logger import log_warn


# kernel for l-bfgs:
# @torch.jit.script
def compute_step_direction(
    alpha_buffer,
    rho_buffer,
    y_buffer,
    s_buffer,
    grad_q,
    m: int,
    epsilon: float,
    stable_mode: bool = True,
):
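    """Compute the L-BFGS step direction using the batched two-loop recursion.

    PyTorch fallback for the fused CUDA kernel. ``s_buffer`` and ``y_buffer``
    hold the stored curvature pairs (newest at index -1); the implicit
    inverse-Hessian estimate is applied to ``grad_q`` and the negated result
    (the descent direction) is returned.
    """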
    # m = 15 (int)
    # y_buffer, s_buffer: m x b x 175
    # q, grad_q: b x 175
    # rho_buffer: m x b x 1
    # alpha_buffer: m x b x 1 # this can be dynamically created
    gq = grad_q.detach().clone()
    rho_s = rho_buffer * s_buffer.transpose(-1, -2)  # batched m_scalar-m_vector product
    for i in range(m - 1, -1, -1):
        alpha_buffer[i] = rho_s[i] @ gq  # batched vector-vector dot product
        gq = gq - (alpha_buffer[i] * y_buffer[i])  # batched scalar-vector product
    var1 = (s_buffer[-1].transpose(-1, -2) @ y_buffer[-1]) / (
        y_buffer[-1].transpose(-1, -2) @ y_buffer[-1]
    )
    if stable_mode:
        var1 = torch.nan_to_num(var1, epsilon, epsilon, epsilon)
    gamma = torch.nn.functional.relu(var1)  # + epsilon
    r = gamma * gq.clone()
    rho_y = rho_buffer * y_buffer.transpose(-1, -2)  # batched m_scalar-m_vector product
    for i in range(m):
        # batched dot product, scalar-vector product
        r = r + (alpha_buffer[i] - (rho_y[i] @ r)) * s_buffer[i]
    return -1.0 * r


@dataclass
class LBFGSOptConfig(NewtonOptConfig):
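    """Configuration for :class:`LBFGSOpt`.

    history: number of (s, y) curvature pairs kept for the two-loop recursion.
    epsilon: fallback value used when ``stable_mode`` sanitizes nan/inf terms.
    use_cuda_kernel: run the fused ``LBFGScu`` kernel instead of the PyTorch path.
    stable_mode: replace nan/inf values in intermediate terms with safe values.
    """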
    history: int = 10
    epsilon: float = 0.01
    use_cuda_kernel: bool = True
    stable_mode: bool = True

    def __post_init__(self):
        return super().__post_init__()


class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
    @profiler.record_function("lbfgs_opt/init")
    def __init__(self, config: Optional[LBFGSOptConfig] = None):
        if config is not None:
            LBFGSOptConfig.__init__(self, **vars(config))
        NewtonOptBase.__init__(self)
        if self.d_opt >= 1024 or self.history > 15:
            log_warn("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>15")
            self.use_cuda_kernel = False
        if self.history >= self.d_opt:
            log_warn("LBFGS: history >= d_opt, reducing history to d_opt-1")
            self.history = self.d_opt - 1

    @profiler.record_function("lbfgs/reset")
    def reset(self):
        self.x_0[:] = 0.0
        self.grad_0[:] = 0.0
        self.s_buffer[:] = 0.0
        self.y_buffer[:] = 0.0
        self.rho_buffer[:] = 0.0
        self.alpha_buffer[:] = 0.0
        self.step_q_buffer[:] = 0.0
        return super().reset()

    def update_nenvs(self, n_envs):
        self.init_hessian(b=n_envs)
        return super().update_nenvs(n_envs)

    def init_hessian(self, b=1):
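        """Allocate the L-BFGS history buffers for a batch of ``b`` problems."""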
        self.x_0 = torch.zeros(
            (b, self.d_opt, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self.grad_0 = torch.zeros(
            (b, self.d_opt, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self.y_buffer = torch.zeros(
            (self.history, b, self.d_opt, 1),
            device=self.tensor_args.device,
            dtype=self.tensor_args.dtype,
        )  # + 1.0
        self.s_buffer = torch.zeros(
            (self.history, b, self.d_opt, 1),
            device=self.tensor_args.device,
            dtype=self.tensor_args.dtype,
        )  # + 1.0
        self.rho_buffer = torch.zeros(
            (self.history, b, 1, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )  # + 1.0
        self.step_q_buffer = torch.zeros(
            (b, self.d_opt), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )
        self.alpha_buffer = torch.zeros(
            (self.history, b, 1, 1), device=self.tensor_args.device, dtype=self.tensor_args.dtype
        )  # + 1.0

    @torch.no_grad()
    def _get_step_direction(self, cost, q, grad_q):
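        """Compute the step direction for the current iterate ``q`` and gradient ``grad_q``.

        Dispatches to the fused ``LBFGScu`` CUDA kernel when enabled; otherwise
        updates the history buffers and runs the PyTorch two-loop recursion.
        """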
        if self.use_cuda_kernel:
            with profiler.record_function("lbfgs/fused"):
                dq = LBFGScu._call_cuda(
                    self.step_q_buffer,
                    self.rho_buffer,
                    self.y_buffer,
                    self.s_buffer,
                    q,
                    grad_q,
                    self.x_0,
                    self.grad_0,
                    self.epsilon,
                    self.stable_mode,
                )
        else:
            grad_q = grad_q.transpose(-1, -2)
            q = q.unsqueeze(-1)
            self._update_buffers(q, grad_q)
            dq = compute_step_direction(
                self.alpha_buffer,
                self.rho_buffer,
                self.y_buffer,
                self.s_buffer,
                grad_q,
                self.history,
                self.epsilon,
                self.stable_mode,
            )
            dq = dq.view(-1, self.d_opt)
        return dq

    def _update_q(self, grad_q):
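        """First (backward) loop of the two-loop recursion; fills ``alpha_buffer``."""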
        q = grad_q.detach().clone()
        rho_s = self.rho_buffer * self.s_buffer.transpose(-1, -2)
        for i in range(self.history - 1, -1, -1):
            self.alpha_buffer[i] = rho_s[i] @ q
            q = q - (self.alpha_buffer[i] * self.y_buffer[i])
        return q

    def _update_r(self, r_init):
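        """Second (forward) loop of the two-loop recursion; returns the negated step direction."""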
        r = r_init.clone()
        rho_y = self.rho_buffer * self.y_buffer.transpose(-1, -2)
        for i in range(self.history):
            r = r + self.s_buffer[i] * (self.alpha_buffer[i] - rho_y[i] @ r)
        return -1.0 * r

    def _update_buffers(self, q, grad_q):
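        """Store the newest (s, y, rho) curvature triplet and the current point and gradient."""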
        y = grad_q - self.grad_0
        s = q - self.x_0
        rho = 1 / (y.transpose(-1, -2) @ s)
        if self.stable_mode:
            rho = torch.nan_to_num(rho, 0, 0, 0)
        self.s_buffer[0] = s
        self.s_buffer[:] = torch.roll(self.s_buffer, -1, dims=0)
        self.y_buffer[0] = y
        self.y_buffer[:] = torch.roll(self.y_buffer, -1, dims=0)  # .copy_(y_buff)
        self.rho_buffer[0] = rho
        self.rho_buffer[:] = torch.roll(self.rho_buffer, -1, dims=0)
        self.x_0.copy_(q)
        self.grad_0.copy_(grad_q)

    def _shift(self, shift_steps=0):
        """Shift the stored L-BFGS state when the optimization horizon moves forward.

        Args:
            shift_steps (int, optional): Number of timesteps to shift by; each step
                shifts the buffers by ``d_action`` entries. Defaults to 0.
        """
        if shift_steps == 0:
            return
        self.reset()
        shift_d = shift_steps * self.d_action
        self.x_0 = self._shift_buffer(self.x_0, shift_d, shift_steps)
        self.grad_0 = self._shift_buffer(self.grad_0, shift_d, shift_steps)
        self.y_buffer = self._shift_buffer(self.y_buffer, shift_d, shift_steps)
        self.s_buffer = self._shift_buffer(self.s_buffer, shift_d, shift_steps)
        super()._shift(shift_steps=shift_steps)

    def _shift_buffer(self, buffer, shift_d, shift_steps: int = 1):
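        """Roll ``buffer`` by ``shift_d`` along the optimization-variable dimension,
        copying the preceding ``d_action`` block into part of the vacated tail."""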
        buffer = buffer.roll(-shift_d, -2)
        end_value = -(shift_steps - 1) * self.d_action
        if end_value == 0:
            end_value = buffer.shape[-2]
        buffer[..., -shift_d:end_value, :] = buffer[
            ..., -shift_d - self.d_action : -shift_d, :
        ].clone()
        return buffer