#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#

# Standard Library
import os
from functools import lru_cache
from typing import Optional

# Third Party
import torch
from packaging import version

# CuRobo
from curobo.util.logger import log_info, log_warn


def find_first_idx(array, value, EQUAL=False):
    """Return the index of the first element of array that is greater than value.

    When EQUAL is True, the comparison is `>=` instead of `>`.
    """
    if EQUAL:
        f_idx = torch.nonzero(array >= value, as_tuple=False)[0].item()
    else:
        f_idx = torch.nonzero(array > value, as_tuple=False)[0].item()
    return f_idx


def find_last_idx(array, value):
    """Return the index of the last element of array that is less than or equal to value."""
    f_idx = torch.nonzero(array <= value, as_tuple=False)[-1].item()
    return f_idx


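# Illustrative usage sketch (not part of the cuRobo API), assuming a sorted 1-D tensor:
#   t = torch.tensor([0.0, 1.0, 2.0, 3.0])
#   find_first_idx(t, 1.0)              # -> 2, first index with t > 1.0
#   find_first_idx(t, 1.0, EQUAL=True)  # -> 1, first index with t >= 1.0
#   find_last_idx(t, 1.0)               # -> 1, last index with t <= 1.0

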
def is_cuda_graph_available():
    """Check if CUDA graph capture is supported by the installed pytorch (requires >= 1.10)."""
    if version.parse(torch.__version__) < version.parse("1.10"):
        log_warn("Disabling CUDA Graph as pytorch < 1.10")
        return False
    return True


def is_torch_compile_available():
    """Check if torch.compile can be used, based on environment variables and installed packages."""
    force_compile = os.environ.get("CUROBO_TORCH_COMPILE_FORCE")
    if force_compile is not None and bool(int(force_compile)):
        return True
    if version.parse(torch.__version__) < version.parse("2.0"):
        log_info("Disabling torch.compile as pytorch < 2.0")
        return False

    env_variable = os.environ.get("CUROBO_TORCH_COMPILE_DISABLE")

    if env_variable is None:
        log_info("CUROBO_TORCH_COMPILE_DISABLE is not set, disabling torch.compile.")
        return False

    if bool(int(env_variable)):
        log_info("CUROBO_TORCH_COMPILE_DISABLE is set to 1, disabling torch.compile.")
        return False

    log_info("CUROBO_TORCH_COMPILE_DISABLE is set to 0, enabling torch.compile.")

    try:
        torch.compile
    except AttributeError:
        log_info("Could not find torch.compile, disabling Torch Compile.")
        return False
    try:
        torch._dynamo
    except AttributeError:
        log_info("Could not find torch._dynamo, disabling Torch Compile.")
        return False
    try:
        # Third Party
        import triton
    except ImportError:
        log_info("Could not find triton, disabling Torch Compile.")
        return False

    return True


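# Illustrative usage sketch (shell invocation, placeholder script name): torch.compile is
# opt-in; it is used only when CUROBO_TORCH_COMPILE_DISABLE is explicitly set to 0, e.g.
#   CUROBO_TORCH_COMPILE_DISABLE=0 python your_script.py
# while CUROBO_TORCH_COMPILE_FORCE=1 bypasses the checks above and forces it on.

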
def get_torch_compile_options() -> dict:
    """Build a torch.compile options dict, keeping only options known to this torch version."""
    options = {}
    if is_torch_compile_available():
        # Third Party
        from torch._inductor import config

        torch._dynamo.config.suppress_errors = True
        use_options = {
            "max_autotune": True,
            "use_mixed_mm": True,
            "conv_1x1_as_mm": True,
            "coordinate_descent_tuning": True,
            "epilogue_fusion": False,
            "coordinate_descent_check_all_directions": True,
            "force_fuse_int_mm_with_mul": True,
            "triton.cudagraphs": False,
            "aggressive_fusion": True,
            "split_reductions": False,
            "worker_start_method": "spawn",
        }
        for k, v in use_options.items():
            if hasattr(config, k):
                options[k] = v
            else:
                log_info("Not found in torch._inductor.config: " + k)
    return options


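# Illustrative usage sketch (`fn` is a placeholder): the returned dict can be passed
# directly to torch.compile, e.g.
#   compiled_fn = torch.compile(fn, options=get_torch_compile_options(), dynamic=True)
# which mirrors how get_torch_jit_decorator() below applies these options.

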
def disable_torch_compile_global():
    """Globally disable torch.compile via torch._dynamo; returns True if the flag was set."""
    if is_torch_compile_available():
        torch._dynamo.config.disable = True
        return True
    return False


def set_torch_compile_global_options():
    """Set global torch._inductor options used by cuRobo; returns True if they were applied."""
    if is_torch_compile_available():
        # Third Party
        from torch._inductor import config

        torch._dynamo.config.suppress_errors = True
        if hasattr(config, "conv_1x1_as_mm"):
            torch._inductor.config.conv_1x1_as_mm = True
        if hasattr(config, "coordinate_descent_tuning"):
            torch._inductor.config.coordinate_descent_tuning = True
        if hasattr(config, "epilogue_fusion"):
            torch._inductor.config.epilogue_fusion = False
        if hasattr(config, "coordinate_descent_check_all_directions"):
            torch._inductor.config.coordinate_descent_check_all_directions = True
        if hasattr(config, "force_fuse_int_mm_with_mul"):
            torch._inductor.config.force_fuse_int_mm_with_mul = True
        if hasattr(config, "use_mixed_mm"):
            torch._inductor.config.use_mixed_mm = True
        return True
    return False


def get_torch_jit_decorator(
    force_jit: bool = False, dynamic: bool = True, only_valid_for_compile: bool = False
):
    """Return torch.compile when available, torch.jit.script as fallback, or a no-op decorator."""
    if not force_jit and is_torch_compile_available():
        return torch.compile(options=get_torch_compile_options(), dynamic=dynamic)
    elif not only_valid_for_compile:
        return torch.jit.script
    else:
        return empty_decorator


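# Illustrative usage sketch (hypothetical function, not part of cuRobo): the decorated
# function keeps its call signature regardless of which backend is selected, e.g.
#   @get_torch_jit_decorator()
#   def scale_points(points: torch.Tensor, scale: float) -> torch.Tensor:
#       return points * scale

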
def is_lru_cache_avaiable():
    """Check if functools.lru_cache should be used, controlled by CUROBO_USE_LRU_CACHE."""
    use_lru_cache = os.environ.get("CUROBO_USE_LRU_CACHE")
    if use_lru_cache is not None:
        return bool(int(use_lru_cache))
    log_info("Environment variable CUROBO_USE_LRU_CACHE is not set, disabling as default.")
    return False


def get_cache_fn_decorator(maxsize: Optional[int] = None):
    """Return functools.lru_cache when enabled, otherwise a no-op decorator."""
    if is_lru_cache_avaiable():
        return lru_cache(maxsize=maxsize)
    else:
        return empty_decorator


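# Illustrative usage sketch (hypothetical function, not part of cuRobo): lru_cache keys on
# hashable arguments, so this suits functions of strings/numbers rather than tensors, e.g.
#   @get_cache_fn_decorator(maxsize=32)
#   def robot_config_path(robot_name: str) -> str:
#       ...

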
def empty_decorator(function):
    """Return the function unchanged; used when compile/jit/caching is disabled."""
    return function