101 lines
2.6 KiB
Python
101 lines
2.6 KiB
Python
#
|
|
# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
|
# property and proprietary rights in and to this material, related
|
|
# documentation and any modifications thereto. Any use, reproduction,
|
|
# disclosure or distribution of this material and related documentation
|
|
# without an express license agreement from NVIDIA CORPORATION or
|
|
# its affiliates is strictly prohibited.
|
|
#
|
|
# Standard Library
|
|
from typing import List
|
|
|
|
# Third Party
|
|
import torch
|
|
|
|
# CuRobo
|
|
from curobo.util.torch_utils import get_torch_jit_decorator
|
|
|
|
|
|
def check_tensor_shapes(new_tensor: torch.Tensor, mem_tensor: torch.Tensor) -> bool:
    """Check whether ``mem_tensor`` is a tensor with exactly the shape of ``new_tensor``.

    Args:
        new_tensor: Tensor whose shape is used as the reference.
        mem_tensor: Candidate destination buffer. May be any object (e.g. ``None``),
            in which case the check fails.

    Returns:
        ``True`` if ``mem_tensor`` is a ``torch.Tensor`` and
        ``mem_tensor.shape == new_tensor.shape``, otherwise ``False``.
    """
    if not isinstance(mem_tensor, torch.Tensor):
        return False
    # Comparing full shapes also rejects a rank mismatch, so no separate
    # ndim check is needed; this always yields a bool, never ``None``.
    return mem_tensor.shape == new_tensor.shape
|
|
|
|
|
|
def copy_tensor(new_tensor: torch.Tensor, mem_tensor: torch.Tensor):
    """Copy ``new_tensor`` into ``mem_tensor`` in place when the shapes agree.

    Args:
        new_tensor: Source tensor.
        mem_tensor: Destination buffer, overwritten via ``Tensor.copy_`` on success.

    Returns:
        ``True`` if the copy was performed. ``False`` if ``check_tensor_shapes``
        rejected the destination, in which case ``mem_tensor`` is left untouched.
    """
    shapes_match = check_tensor_shapes(new_tensor, mem_tensor)
    if not shapes_match:
        return False
    mem_tensor.copy_(new_tensor)
    return True
|
|
|
|
|
|
def copy_if_not_none(new_tensor, ref_tensor):
    """Copy ``new_tensor`` into ``ref_tensor`` in place when both are present.

    Unlike ``clone_if_not_none``, this does not clone. It either writes into the
    existing ``ref_tensor`` buffer or hands back ``new_tensor`` itself.

    Args:
        new_tensor (torch.Tensor or None): Source values.
        ref_tensor (torch.Tensor or None): Existing destination buffer.

    Returns:
        torch.Tensor or None:
            - ``ref_tensor``, after ``ref_tensor.copy_(new_tensor)``, when both are
              not ``None``;
            - ``new_tensor`` itself (aliased, not cloned) when ``ref_tensor`` is
              ``None``;
            - ``ref_tensor`` unchanged (possibly ``None``) when ``new_tensor`` is
              ``None``.
    """
    if new_tensor is None:
        # Nothing to copy; keep whatever buffer the caller already holds.
        return ref_tensor
    if ref_tensor is None:
        # No buffer yet: adopt the caller's tensor directly (no copy is made).
        return new_tensor
    ref_tensor.copy_(new_tensor)
    return ref_tensor
|
|
|
|
|
|
def clone_if_not_none(x):
    """Return an independent copy of ``x``, passing ``None`` straight through.

    Args:
        x (torch.Tensor or None): Tensor to duplicate.

    Returns:
        torch.Tensor or None: ``x.clone()`` when ``x`` is given, else ``None``.
    """
    return None if x is None else x.clone()
|
|
|
|
|
|
@get_torch_jit_decorator()
def cat_sum(tensor_list: List[torch.Tensor]):
    """Element-wise sum of a list of tensors.

    The tensors are stacked along a new leading axis (``torch.stack`` requires
    identical shapes), and that axis is then reduced away.
    """
    stacked = torch.stack(tensor_list, dim=0)
    return stacked.sum(dim=0)
|
|
|
|
|
|
@get_torch_jit_decorator()
def cat_sum_horizon(tensor_list: List[torch.Tensor]):
    """Sum a list of tensors element-wise and also reduce their last dimension.

    The tensors are stacked along a new leading axis, and that axis plus the
    trailing axis are summed in a single reduction. The result therefore has the
    shape of one input with its last dimension removed.
    """
    stacked = torch.stack(tensor_list, dim=0)
    return torch.sum(stacked, dim=(0, -1))
|
|
|
|
|
|
@get_torch_jit_decorator()
def cat_max(tensor_list: List[torch.Tensor]):
    """Element-wise maximum over a list of equally-shaped tensors.

    Only the max values are returned; the argmax indices are discarded.
    """
    stacked = torch.stack(tensor_list, dim=0)
    max_values, _ = torch.max(stacked, dim=0)
    return max_values
|
|
|
|
|
|
def tensor_repeat_seeds(tensor, num_seeds):
    """Repeat every row of ``tensor`` ``num_seeds`` times, keeping copies adjacent.

    Input row ``i`` becomes output rows ``i * num_seeds`` through
    ``(i + 1) * num_seeds - 1``.

    Args:
        tensor (torch.Tensor): Must be viewable as
            ``(tensor.shape[0], 1, tensor.shape[-1])``, e.g. a 2-D tensor.
        num_seeds (int): Number of copies of each row.

    Returns:
        torch.Tensor: Shape ``(tensor.shape[0] * num_seeds, tensor.shape[-1])``.
    """
    n_rows = tensor.shape[0]
    n_cols = tensor.shape[-1]
    per_row = tensor.view(n_rows, 1, n_cols).repeat(1, num_seeds, 1)
    return per_row.reshape(n_rows * num_seeds, n_cols)
|
|
|
|
|
|
@get_torch_jit_decorator()
def fd_tensor(p: torch.Tensor, dt: torch.Tensor):
    """Forward finite difference of ``p`` along dimension -2, scaled by ``1 / dt``.

    Entry ``i`` of the result is ``(p[..., i + 1, :] - p[..., i, :]) / dt``, so
    the output has one fewer entry than ``p`` along dim -2.

    Args:
        p: Tensor with at least 2 dims; differences are taken along dim -2.
        dt: Time step(s). ``(1 / dt).unsqueeze(-1)`` must broadcast against ``p``.
            NOTE(review): presumably ``dt`` is per-batch (e.g. ``(B, 1)``) —
            confirm against callers.
    """
    next_p = torch.roll(p, -1, -2)
    inv_dt = (1 / dt).unsqueeze(-1)
    # Scale before slicing so inv_dt broadcasts against the full-length axis;
    # the final slice drops the wrapped-around last entry produced by roll.
    scaled = (next_p - p) * inv_dt
    return scaled[..., :-1, :]
|