release repository

This commit is contained in:
Balakumar Sundaralingam
2023-10-26 04:17:19 -07:00
commit 07e6ccfc91
287 changed files with 70659 additions and 0 deletions

View File

@@ -0,0 +1,73 @@
#
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Standard Library
from typing import List
# Third Party
import torch
def check_tensor_shapes(new_tensor: torch.Tensor, mem_tensor: torch.Tensor):
if not isinstance(mem_tensor, torch.Tensor):
return False
if len(mem_tensor.shape) != len(new_tensor.shape):
return False
if mem_tensor.shape == new_tensor.shape:
return True
def copy_tensor(new_tensor: torch.Tensor, mem_tensor: torch.Tensor):
if check_tensor_shapes(new_tensor, mem_tensor):
mem_tensor.copy_(new_tensor)
return True
return False
def copy_if_not_none(x):
"""Clones x if it's not None.
TODO: Rename this to clone_if_not_none
Args:
x (torch.Tensor): _description_
Returns:
_type_: _description_
"""
if x is not None:
return x.clone()
return None
@torch.jit.script
def cat_sum(tensor_list: List[torch.Tensor]):
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=0)
return cat_tensor
@torch.jit.script
def cat_max(tensor_list: List[torch.Tensor]):
cat_tensor = torch.max(torch.stack(tensor_list, dim=0), dim=0)[0]
return cat_tensor
def tensor_repeat_seeds(tensor, num_seeds):
return (
tensor.view(tensor.shape[0], 1, tensor.shape[-1])
.repeat(1, num_seeds, 1)
.reshape(tensor.shape[0] * num_seeds, tensor.shape[-1])
)
@torch.jit.script
def fd_tensor(p: torch.Tensor, dt: torch.Tensor):
out = ((torch.roll(p, -1, -2) - p) * (1 / dt).unsqueeze(-1))[..., :-1, :]
return out