Files
gen_data_curobo/src/curobo/util/torch_utils.py
Balakumar Sundaralingam 0c51dd2da8 improved joint space planning
2024-05-30 14:42:22 -07:00

170 lines
5.2 KiB
Python

#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
import os
from functools import lru_cache
from typing import Optional
# Third Party
import torch
from packaging import version
# CuRobo
from curobo.util.logger import log_info, log_warn
def find_first_idx(array, value, EQUAL=False):
    """Return the position of the first entry of ``array`` that exceeds ``value``.

    Args:
        array: Tensor to search. Presumably 1-D: ``.item()`` on the first
            ``nonzero`` row only succeeds when that row holds one index.
        value: Threshold to compare against.
        EQUAL: When True, the comparison is ``>=`` instead of ``>``.

    Returns:
        Python int index of the first matching entry.

    Raises:
        IndexError: if no entry satisfies the comparison.
    """
    mask = (array >= value) if EQUAL else (array > value)
    hits = torch.nonzero(mask, as_tuple=False)
    first_hit = hits[0]
    return first_hit.item()
def find_last_idx(array, value):
    """Return the position of the last entry of ``array`` that is ``<= value``.

    Args:
        array: Tensor to search. Presumably 1-D, for the same ``.item()``
            reason as ``find_first_idx``.
        value: Upper bound (inclusive).

    Returns:
        Python int index of the last matching entry.

    Raises:
        IndexError: if no entry is ``<= value``.
    """
    positions = torch.nonzero(array <= value, as_tuple=False)
    last_position = positions[-1]
    return last_position.item()
def is_cuda_graph_available():
    """Report whether the installed PyTorch is new enough for CUDA graphs.

    Returns:
        False (after logging a warning) when ``torch.__version__`` is below
        1.10, otherwise True.
    """
    torch_too_old = version.parse(torch.__version__) < version.parse("1.10")
    if torch_too_old:
        log_warn("Disabling CUDA Graph as pytorch < 1.10")
    return not torch_too_old
def is_torch_compile_available():
    """Decide whether cuRobo should use ``torch.compile``.

    Checks run in this order:
      1. ``CUROBO_TORCH_COMPILE_FORCE`` set to a non-zero integer returns True
         immediately, skipping every other check.
      2. PyTorch older than 2.0 returns False.
      3. ``CUROBO_TORCH_COMPILE_DISABLE`` unset returns False (compile is opt-in).
      4. ``CUROBO_TORCH_COMPILE_DISABLE`` set to a non-zero integer returns False.
      5. If ``torch.compile``, ``torch._dynamo`` or the ``triton`` package cannot
         be accessed, returns False.

    Returns:
        True if torch.compile should be used, False otherwise.

    Raises:
        ValueError: if either environment variable holds a value ``int()``
            cannot parse.
    """
    force_compile = os.environ.get("CUROBO_TORCH_COMPILE_FORCE")
    if force_compile is not None and bool(int(force_compile)):
        return True
    if version.parse(torch.__version__) < version.parse("2.0"):
        log_info("Disabling torch.compile as pytorch < 2.0")
        return False
    env_variable = os.environ.get("CUROBO_TORCH_COMPILE_DISABLE")
    if env_variable is None:
        log_info("Environment variable for CUROBO_TORCH_COMPILE is not set, Disabling.")
        return False
    if bool(int(env_variable)):
        log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Disable")
        return False
    log_info("Environment variable for CUROBO_TORCH_COMPILE is set to Enable")

    # Probe the compile stack. ``except Exception`` (rather than a bare
    # ``except:``) lets KeyboardInterrupt/SystemExit propagate, while any
    # other failure still falls back to "compile disabled".
    try:
        torch.compile  # noqa: B018 - attribute probe only
    except Exception:
        log_info("Could not find torch.compile, disabling Torch Compile.")
        return False
    try:
        torch._dynamo  # noqa: B018 - attribute probe only
    except Exception:
        log_info("Could not find torch._dynamo, disabling Torch Compile.")
        return False
    try:
        # Third Party
        import triton  # noqa: F401 - imported only to check availability
    except Exception:
        log_info("Could not find triton, disabling Torch Compile.")
        return False
    return True
def get_torch_compile_options() -> dict:
    """Build the inductor ``options`` dict passed to ``torch.compile``.

    Side effect: when compile is available, sets
    ``torch._dynamo.config.suppress_errors = True``.

    Returns:
        An empty dict if torch.compile is unavailable. Otherwise, the requested
        options restricted to names the installed ``torch._inductor.config``
        exposes; each unsupported name is logged and dropped.
    """
    if not is_torch_compile_available():
        return {}

    # Third Party
    from torch._inductor import config

    torch._dynamo.config.suppress_errors = True
    requested = {
        "max_autotune": True,
        "use_mixed_mm": True,
        "conv_1x1_as_mm": True,
        "coordinate_descent_tuning": True,
        "epilogue_fusion": False,
        "coordinate_descent_check_all_directions": True,
        "force_fuse_int_mm_with_mul": True,
        "triton.cudagraphs": False,
        "aggressive_fusion": True,
        "split_reductions": False,
        "worker_start_method": "spawn",
    }
    supported = {}
    for name, setting in requested.items():
        # Inductor config keys vary between PyTorch releases; keep only known ones.
        if not hasattr(config, name):
            log_info("Not found in torch.compile: " + name)
            continue
        supported[name] = setting
    return supported
def disable_torch_compile_global():
    """Globally disable dynamo by setting ``torch._dynamo.config.disable``.

    Returns:
        True if the flag was set, False when torch.compile is unavailable
        (nothing is changed in that case).
    """
    if not is_torch_compile_available():
        return False
    torch._dynamo.config.disable = True
    return True
def set_torch_compile_global_options():
    """Apply cuRobo's preferred settings directly to ``torch._inductor.config``.

    Also sets ``torch._dynamo.config.suppress_errors = True``. An option is
    written only if the installed inductor config defines it.

    Returns:
        True if the settings were applied, False when torch.compile is
        unavailable (nothing is changed in that case).
    """
    if not is_torch_compile_available():
        return False

    # Third Party
    from torch._inductor import config

    torch._dynamo.config.suppress_errors = True
    # Applied in this order; names missing from older/newer PyTorch are skipped.
    overrides = (
        ("conv_1x1_as_mm", True),
        ("coordinate_descent_tuning", True),
        ("epilogue_fusion", False),
        ("coordinate_descent_check_all_directions", True),
        ("force_fuse_int_mm_with_mul", True),
        ("use_mixed_mm", True),
    )
    for name, setting in overrides:
        if hasattr(config, name):
            setattr(config, name, setting)
    return True
def get_torch_jit_decorator(
    force_jit: bool = False, dynamic: bool = True, only_valid_for_compile: bool = False
):
    """Choose the decorator used to accelerate a function.

    Args:
        force_jit: Skip torch.compile even if it is available.
        dynamic: Forwarded to ``torch.compile(dynamic=...)``.
        only_valid_for_compile: When torch.compile is not used, return a no-op
            decorator instead of ``torch.jit.script``.

    Returns:
        A ``torch.compile`` decorator, ``torch.jit.script``, or ``empty_decorator``.
    """
    # ``not force_jit`` is evaluated first so the availability probe is skipped when forced.
    compile_selected = (not force_jit) and is_torch_compile_available()
    if compile_selected:
        return torch.compile(options=get_torch_compile_options(), dynamic=dynamic)
    if only_valid_for_compile:
        return empty_decorator
    return torch.jit.script
def is_lru_cache_avaiable():
    """Report whether function-level LRU caching is enabled.

    Controlled by the ``CUROBO_USE_LRU_CACHE`` environment variable. A non-zero
    integer enables caching and zero disables it. When the variable is unset,
    caching is disabled by default.

    Returns:
        True if LRU caching should be used, otherwise False.

    Raises:
        ValueError: if ``CUROBO_USE_LRU_CACHE`` cannot be parsed by ``int()``.
    """
    use_lru_cache = os.environ.get("CUROBO_USE_LRU_CACHE")
    if use_lru_cache is not None:
        return bool(int(use_lru_cache))
    # The default is disabled (returns False); the log message must say so.
    log_info("Environment variable for CUROBO_USE_LRU_CACHE is not set, Disabling as default.")
    return False
def get_cache_fn_decorator(maxsize: Optional[int] = None):
    """Return a caching decorator, or a no-op when LRU caching is disabled.

    Args:
        maxsize: Forwarded to ``functools.lru_cache``; ``None`` means unbounded.

    Returns:
        ``lru_cache(maxsize=maxsize)`` if ``is_lru_cache_avaiable()`` is True,
        else ``empty_decorator``.
    """
    if not is_lru_cache_avaiable():
        return empty_decorator
    return lru_cache(maxsize=maxsize)
def empty_decorator(function):
    """No-op decorator: hand back ``function`` unchanged (same object)."""
    unchanged = function
    return unchanged