Significantly improved convergence for mesh and cuboid, new ESDF collision.

This commit is contained in:
Balakumar Sundaralingam
2024-03-18 11:19:48 -07:00
parent 286b3820a5
commit b1f63e8778
100 changed files with 7587 additions and 2589 deletions

View File

@@ -8,12 +8,15 @@
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
import os
# Third Party
import torch
from packaging import version
# CuRobo
from curobo.util.logger import log_warn
from curobo.util.logger import log_info, log_warn
def find_first_idx(array, value, EQUAL=False):
@@ -31,13 +34,119 @@ def find_last_idx(array, value):
def is_cuda_graph_available():
if version.parse(torch.__version__) < version.parse("1.10"):
log_warn("WARNING: Disabling CUDA Graph as pytorch < 1.10")
log_warn("Disabling CUDA Graph as pytorch < 1.10")
return False
return True
def is_torch_compile_available():
force_compile = os.environ.get("CUROBO_TORCH_COMPILE_FORCE")
if force_compile is not None and bool(int(force_compile)):
return True
if version.parse(torch.__version__) < version.parse("2.0"):
log_warn("WARNING: Disabling compile as pytorch < 2.0")
log_info("Disabling torch.compile as pytorch < 2.0")
return False
env_variable = os.environ.get("CUROBO_TORCH_COMPILE_DISABLE")
if env_variable is None:
log_info("Environment variable for CUROBO_TORCH_COMPILE is not set, Disabling.")
return False
if bool(int(env_variable)):
log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Disable")
return False
log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Enable")
try:
torch.compile
except:
log_info("Could not find torch.compile, disabling Torch Compile.")
return False
try:
torch._dynamo
except:
log_info("Could not find torch._dynamo, disabling Torch Compile.")
return False
try:
# Third Party
import triton
except:
log_info("Could not find triton, disabling Torch Compile.")
return False
return True
def get_torch_compile_options() -> dict:
options = {}
if is_torch_compile_available():
# Third Party
from torch._inductor import config
torch._dynamo.config.suppress_errors = True
use_options = {
"max_autotune": True,
"use_mixed_mm": True,
"conv_1x1_as_mm": True,
"coordinate_descent_tuning": True,
"epilogue_fusion": False,
"coordinate_descent_check_all_directions": True,
"force_fuse_int_mm_with_mul": True,
"triton.cudagraphs": False,
"aggressive_fusion": True,
"split_reductions": False,
"worker_start_method": "spawn",
}
for k in use_options.keys():
if hasattr(config, k):
options[k] = use_options[k]
else:
log_info("Not found in torch.compile: " + k)
return options
def disable_torch_compile_global():
if is_torch_compile_available():
torch._dynamo.config.disable = True
return True
return False
def set_torch_compile_global_options():
if is_torch_compile_available():
# Third Party
from torch._inductor import config
torch._dynamo.config.suppress_errors = True
if hasattr(config, "conv_1x1_as_mm"):
torch._inductor.config.conv_1x1_as_mm = True
if hasattr(config, "coordinate_descent_tuning"):
torch._inductor.config.coordinate_descent_tuning = True
if hasattr(config, "epilogue_fusion"):
torch._inductor.config.epilogue_fusion = False
if hasattr(config, "coordinate_descent_check_all_directions"):
torch._inductor.config.coordinate_descent_check_all_directions = True
if hasattr(config, "force_fuse_int_mm_with_mul"):
torch._inductor.config.force_fuse_int_mm_with_mul = True
if hasattr(config, "use_mixed_mm"):
torch._inductor.config.use_mixed_mm = True
return True
return False
def get_torch_jit_decorator(
force_jit: bool = False, dynamic: bool = True, only_valid_for_compile: bool = False
):
if not force_jit and is_torch_compile_available():
return torch.compile(options=get_torch_compile_options(), dynamic=dynamic)
elif not only_valid_for_compile:
return torch.jit.script
else:
return empty_decorator
def empty_decorator(function):
return function