Significantly improved convergence for mesh and cuboid, new ESDF collision.
@@ -8,12 +8,15 @@
 # without an express license agreement from NVIDIA CORPORATION or
 # its affiliates is strictly prohibited.
 #
+# Standard Library
+import os
+
 # Third Party
 import torch
 from packaging import version
 
 # CuRobo
-from curobo.util.logger import log_warn
+from curobo.util.logger import log_info, log_warn
 
 
 def find_first_idx(array, value, EQUAL=False):
@@ -31,13 +34,119 @@ def find_last_idx(array, value):
 
 def is_cuda_graph_available():
     if version.parse(torch.__version__) < version.parse("1.10"):
-        log_warn("WARNING: Disabling CUDA Graph as pytorch < 1.10")
+        log_warn("Disabling CUDA Graph as pytorch < 1.10")
         return False
     return True
 
 
 def is_torch_compile_available():
+    force_compile = os.environ.get("CUROBO_TORCH_COMPILE_FORCE")
+    if force_compile is not None and bool(int(force_compile)):
+        return True
     if version.parse(torch.__version__) < version.parse("2.0"):
-        log_warn("WARNING: Disabling compile as pytorch < 2.0")
+        log_info("Disabling torch.compile as pytorch < 2.0")
         return False
+
+    env_variable = os.environ.get("CUROBO_TORCH_COMPILE_DISABLE")
+
+    if env_variable is None:
+        log_info("Environment variable for CUROBO_TORCH_COMPILE is not set, Disabling.")
+
+        return False
+
+    if bool(int(env_variable)):
+        log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Disable")
+        return False
+
+    log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Enable")
+
+    try:
+        torch.compile
+    except:
+        log_info("Could not find torch.compile, disabling Torch Compile.")
+        return False
+    try:
+        torch._dynamo
+    except:
+        log_info("Could not find torch._dynamo, disabling Torch Compile.")
+        return False
+    try:
+        # Third Party
+        import triton
+    except:
+        log_info("Could not find triton, disabling Torch Compile.")
+        return False
+
     return True
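Taken together, the new checks make torch.compile strictly opt-in: it runs only on PyTorch >= 2.0, only when torch.compile, torch._dynamo, and triton are all importable, and only when the user asks for it through the environment. A minimal sketch of the three opt-in states, assuming the function lives in curobo.util.torch_utils (the file path is not shown in this diff):

    import os
    from curobo.util.torch_utils import is_torch_compile_available

    # Default: CUROBO_TORCH_COMPILE_DISABLE is unset, so compile stays off.
    print(is_torch_compile_available())  # False

    # "0" means "do not disable": compile is enabled if all library checks pass.
    os.environ["CUROBO_TORCH_COMPILE_DISABLE"] = "0"
    print(is_torch_compile_available())

    # Force-enable, skipping the version and environment checks entirely.
    os.environ["CUROBO_TORCH_COMPILE_FORCE"] = "1"
    print(is_torch_compile_available())  # True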
+def get_torch_compile_options() -> dict:
+    options = {}
+    if is_torch_compile_available():
+        # Third Party
+        from torch._inductor import config
+
+        torch._dynamo.config.suppress_errors = True
+        use_options = {
+            "max_autotune": True,
+            "use_mixed_mm": True,
+            "conv_1x1_as_mm": True,
+            "coordinate_descent_tuning": True,
+            "epilogue_fusion": False,
+            "coordinate_descent_check_all_directions": True,
+            "force_fuse_int_mm_with_mul": True,
+            "triton.cudagraphs": False,
+            "aggressive_fusion": True,
+            "split_reductions": False,
+            "worker_start_method": "spawn",
+        }
+        for k in use_options.keys():
+            if hasattr(config, k):
+                options[k] = use_options[k]
+            else:
+                log_info("Not found in torch.compile: " + k)
+    return options
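get_torch_compile_options filters the tuning dictionary against torch._inductor.config, so options a given PyTorch build does not expose are dropped and logged instead of raising. A hedged usage sketch; my_fn is a hypothetical function, and the torch.compile call mirrors what get_torch_jit_decorator does further down in this hunk:

    import torch
    from curobo.util.torch_utils import get_torch_compile_options

    def my_fn(x: torch.Tensor) -> torch.Tensor:
        return torch.sin(x) + torch.cos(x)

    # Only inductor options that exist in this torch build are passed through;
    # when compile is unavailable the dict is empty, which torch.compile accepts.
    compiled_fn = torch.compile(my_fn, options=get_torch_compile_options(), dynamic=True)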
+def disable_torch_compile_global():
+    if is_torch_compile_available():
+        torch._dynamo.config.disable = True
+        return True
+    return False
+
+
+def set_torch_compile_global_options():
+    if is_torch_compile_available():
+        # Third Party
+        from torch._inductor import config
+
+        torch._dynamo.config.suppress_errors = True
+        if hasattr(config, "conv_1x1_as_mm"):
+            torch._inductor.config.conv_1x1_as_mm = True
+        if hasattr(config, "coordinate_descent_tuning"):
+            torch._inductor.config.coordinate_descent_tuning = True
+        if hasattr(config, "epilogue_fusion"):
+            torch._inductor.config.epilogue_fusion = False
+        if hasattr(config, "coordinate_descent_check_all_directions"):
+            torch._inductor.config.coordinate_descent_check_all_directions = True
+        if hasattr(config, "force_fuse_int_mm_with_mul"):
+            torch._inductor.config.force_fuse_int_mm_with_mul = True
+        if hasattr(config, "use_mixed_mm"):
+            torch._inductor.config.use_mixed_mm = True
+        return True
+    return False
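These two helpers toggle compilation process-wide: disable_torch_compile_global turns dynamo off entirely, while set_torch_compile_global_options applies the same inductor settings as get_torch_compile_options, but globally and guarded by hasattr for older builds. Both return False and change nothing when compile is unavailable. A speculative startup snippet:

    from curobo.util.torch_utils import set_torch_compile_global_options

    # No-op (returns False) when torch.compile is unavailable or disabled via env.
    set_torch_compile_global_options()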
+def get_torch_jit_decorator(
+    force_jit: bool = False, dynamic: bool = True, only_valid_for_compile: bool = False
+):
+    if not force_jit and is_torch_compile_available():
+        return torch.compile(options=get_torch_compile_options(), dynamic=dynamic)
+    elif not only_valid_for_compile:
+        return torch.jit.script
+    else:
+        return empty_decorator
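get_torch_jit_decorator looks like the intended single entry point for the rest of the codebase: torch.compile with the tuned options when available, torch.jit.script as the fallback, or the pass-through empty_decorator for code that is only valid under compile. A hypothetical decorated function (the type annotations matter, since the torch.jit.script fallback requires them):

    import torch
    from curobo.util.torch_utils import get_torch_jit_decorator

    @get_torch_jit_decorator()
    def scale_points(points: torch.Tensor, scale: float) -> torch.Tensor:
        # Compiled by torch.compile when enabled, scripted by TorchScript otherwise.
        return points * scale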