update to 0.6.2

This commit is contained in:
Balakumar Sundaralingam
2023-12-15 02:01:33 -08:00
parent d85ae41fba
commit 58958bbcce
105 changed files with 2514 additions and 934 deletions

View File

@@ -24,9 +24,16 @@ from curobo.util.logger import log_warn
# kernel for l-bfgs:
@torch.jit.script
# @torch.jit.script
def compute_step_direction(
alpha_buffer, rho_buffer, y_buffer, s_buffer, grad_q, m: int, epsilon, stable_mode: bool = True
alpha_buffer,
rho_buffer,
y_buffer,
s_buffer,
grad_q,
m: int,
epsilon: float,
stable_mode: bool = True,
):
# m = 15 (int)
# y_buffer, s_buffer: m x b x 175
@@ -70,12 +77,12 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
if config is not None:
LBFGSOptConfig.__init__(self, **vars(config))
NewtonOptBase.__init__(self)
if self.d_opt >= 1024 or self.history >= 512:
log_warn("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>=512")
if self.d_opt >= 1024 or self.history > 15:
log_warn("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>15")
self.use_cuda_kernel = False
if self.history > self.d_opt:
if self.history >= self.d_opt:
log_warn("LBFGS: history >= d_opt, reducing history to d_opt-1")
self.history = self.d_opt
self.history = self.d_opt - 1
@profiler.record_function("lbfgs/reset")
def reset(self):