update to 0.6.2
@@ -24,9 +24,16 @@ from curobo.util.logger import log_warn
 
 
 # kernel for l-bfgs:
-@torch.jit.script
+# @torch.jit.script
 def compute_step_direction(
-    alpha_buffer, rho_buffer, y_buffer, s_buffer, grad_q, m: int, epsilon, stable_mode: bool = True
+    alpha_buffer,
+    rho_buffer,
+    y_buffer,
+    s_buffer,
+    grad_q,
+    m: int,
+    epsilon: float,
+    stable_mode: bool = True,
 ):
     # m = 15 (int)
     # y_buffer, s_buffer: m x b x 175
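Note: the body of compute_step_direction is not shown in this hunk. For reference, the step direction such a kernel computes is the standard L-BFGS two-loop recursion; below is a minimal PyTorch sketch of that recursion, assuming [m, b, d] history buffers as in the shape comments above. The name lbfgs_two_loop and the local rho/alpha allocations are illustrative only; the actual kernel takes preallocated alpha_buffer/rho_buffer and need not match this sketch.

import torch


def lbfgs_two_loop(grad_q, s_buffer, y_buffer, epsilon=1e-8, stable_mode=True):
    # grad_q: [b, d] current gradient; s_buffer/y_buffer: [m, b, d] histories of
    # parameter differences s_k = x_{k+1} - x_k and gradient differences
    # y_k = g_{k+1} - g_k, oldest first. Returns a [b, d] descent direction.
    m = s_buffer.shape[0]
    eps = epsilon if stable_mode else 0.0  # stable_mode guards the divisions
    q = grad_q.clone()
    rho = torch.empty(m, q.shape[0], device=q.device, dtype=q.dtype)
    alpha = torch.empty_like(rho)
    # First loop: newest to oldest history pair.
    for i in range(m - 1, -1, -1):
        rho[i] = 1.0 / ((y_buffer[i] * s_buffer[i]).sum(-1) + eps)
        alpha[i] = rho[i] * (s_buffer[i] * q).sum(-1)
        q = q - alpha[i].unsqueeze(-1) * y_buffer[i]
    # Scale by the initial inverse-Hessian estimate gamma = (s.y) / (y.y),
    # taken from the newest pair.
    sy = (s_buffer[-1] * y_buffer[-1]).sum(-1)
    yy = (y_buffer[-1] * y_buffer[-1]).sum(-1) + eps
    r = (sy / yy).unsqueeze(-1) * q
    # Second loop: oldest to newest.
    for i in range(m):
        beta = rho[i] * (y_buffer[i] * r).sum(-1)
        r = r + (alpha[i] - beta).unsqueeze(-1) * s_buffer[i]
    return -r  # r approximates H^{-1} g, so the descent step is -H^{-1} g

For example, with m, b, d = 15, 4, 175 (matching the comments in the diff), lbfgs_two_loop(torch.randn(4, 175), torch.randn(15, 4, 175), torch.randn(15, 4, 175)) returns a [4, 175] step direction. The eps term plays the same stabilizing role that the epsilon argument presumably plays in the signature above.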
@@ -70,12 +77,12 @@ class LBFGSOpt(NewtonOptBase, LBFGSOptConfig):
         if config is not None:
             LBFGSOptConfig.__init__(self, **vars(config))
         NewtonOptBase.__init__(self)
-        if self.d_opt >= 1024 or self.history >= 512:
-            log_warn("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>=512")
+        if self.d_opt >= 1024 or self.history > 15:
+            log_warn("LBFGS: Not using LBFGS Cuda Kernel as d_opt>1024 or history>15")
             self.use_cuda_kernel = False
-        if self.history > self.d_opt:
-            self.history = self.d_opt
+        if self.history >= self.d_opt:
+            log_warn("LBFGS: history >= d_opt, reducing history to d_opt-1")
+            self.history = self.d_opt - 1
 
     @profiler.record_function("lbfgs/reset")
     def reset(self):
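The two guards tightened above are easy to exercise in isolation. The sketch below restates the same logic as a standalone function (clamp_lbfgs_settings is a hypothetical name for illustration, not part of cuRobo's API): the CUDA kernel path now requires d_opt < 1024 and history <= 15, and history is clamped to d_opt - 1, presumably because keeping at least as many (s, y) history pairs as optimization variables adds no curvature information.

def clamp_lbfgs_settings(d_opt: int, history: int, use_cuda_kernel: bool = True):
    # Mirror of the checks in LBFGSOpt.__init__ after this commit.
    if d_opt >= 1024 or history > 15:
        # CUDA kernel limits (per the warning above): d_opt < 1024, history <= 15.
        use_cuda_kernel = False
    if history >= d_opt:
        # history must stay strictly below d_opt; clamp to d_opt - 1.
        history = d_opt - 1
    return history, use_cuda_kernel

For example, clamp_lbfgs_settings(d_opt=32, history=64) returns (31, False): history=64 exceeds 15, so the CUDA kernel is disabled, and 64 >= 32, so history is reduced to 31.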