constrained planning, robot segmentation
This commit is contained in:
@@ -54,7 +54,7 @@ class WrapBase(WrapConfig):
|
||||
def __init__(self, config: Optional[WrapConfig] = None):
|
||||
if config is not None:
|
||||
WrapConfig.__init__(self, **vars(config))
|
||||
self.n_envs = 1
|
||||
self.n_problems = 1
|
||||
self.opt_dt = 0
|
||||
self._rollout_list = None
|
||||
self._opt_rollouts = None
|
||||
@@ -83,11 +83,11 @@ class WrapBase(WrapConfig):
|
||||
debug_list.append(opt.debug_cost)
|
||||
return debug_list
|
||||
|
||||
def update_nenvs(self, n_envs):
|
||||
if n_envs != self.n_envs:
|
||||
self.n_envs = n_envs
|
||||
def update_nproblems(self, n_problems):
|
||||
if n_problems != self.n_problems:
|
||||
self.n_problems = n_problems
|
||||
for opt in self.optimizers:
|
||||
opt.update_nenvs(self.n_envs)
|
||||
opt.update_nproblems(self.n_problems)
|
||||
|
||||
def update_params(self, goal: Goal):
|
||||
with profiler.record_function("wrap_base/safety/update_params"):
|
||||
@@ -117,6 +117,12 @@ class WrapBase(WrapConfig):
|
||||
opt.reset_cuda_graph()
|
||||
self._init_solver = False
|
||||
|
||||
def reset_shape(self):
|
||||
self.safety_rollout.reset_shape()
|
||||
for opt in self.optimizers:
|
||||
opt.reset_shape()
|
||||
self._init_solver = False
|
||||
|
||||
@property
|
||||
def rollout_fn(self):
|
||||
return self.safety_rollout
|
||||
|
||||
Reference in New Issue
Block a user