release repository
This commit is contained in:
73
src/curobo/util/tensor_util.py
Normal file
73
src/curobo/util/tensor_util.py
Normal file
@@ -0,0 +1,73 @@
|
||||
#
|
||||
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
#
|
||||
# Standard Library
|
||||
from typing import List
|
||||
|
||||
# Third Party
|
||||
import torch
|
||||
|
||||
|
||||
def check_tensor_shapes(new_tensor: torch.Tensor, mem_tensor: torch.Tensor):
|
||||
if not isinstance(mem_tensor, torch.Tensor):
|
||||
return False
|
||||
if len(mem_tensor.shape) != len(new_tensor.shape):
|
||||
return False
|
||||
if mem_tensor.shape == new_tensor.shape:
|
||||
return True
|
||||
|
||||
|
||||
def copy_tensor(new_tensor: torch.Tensor, mem_tensor: torch.Tensor):
|
||||
if check_tensor_shapes(new_tensor, mem_tensor):
|
||||
mem_tensor.copy_(new_tensor)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def copy_if_not_none(x):
|
||||
"""Clones x if it's not None.
|
||||
TODO: Rename this to clone_if_not_none
|
||||
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
if x is not None:
|
||||
return x.clone()
|
||||
return None
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def cat_sum(tensor_list: List[torch.Tensor]):
|
||||
cat_tensor = torch.sum(torch.stack(tensor_list, dim=0), dim=0)
|
||||
return cat_tensor
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def cat_max(tensor_list: List[torch.Tensor]):
|
||||
cat_tensor = torch.max(torch.stack(tensor_list, dim=0), dim=0)[0]
|
||||
return cat_tensor
|
||||
|
||||
|
||||
def tensor_repeat_seeds(tensor, num_seeds):
|
||||
return (
|
||||
tensor.view(tensor.shape[0], 1, tensor.shape[-1])
|
||||
.repeat(1, num_seeds, 1)
|
||||
.reshape(tensor.shape[0] * num_seeds, tensor.shape[-1])
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fd_tensor(p: torch.Tensor, dt: torch.Tensor):
|
||||
out = ((torch.roll(p, -1, -2) - p) * (1 / dt).unsqueeze(-1))[..., :-1, :]
|
||||
return out
|
||||
Reference in New Issue
Block a user