bug fix for stomp sampling
This commit is contained in:
@@ -510,7 +510,21 @@ class HaltonGenerator:
|
|||||||
def gaussian_transform(
|
def gaussian_transform(
|
||||||
uniform_samples: torch.Tensor, proj_mat: torch.Tensor, i_mat: torch.Tensor, variance: float
|
uniform_samples: torch.Tensor, proj_mat: torch.Tensor, i_mat: torch.Tensor, variance: float
|
||||||
):
|
):
|
||||||
gaussian_halton_samples = proj_mat * torch.erfinv(2 * uniform_samples - 1)
|
"""Compute a guassian transform of uniform samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uniform_samples (torch.Tensor): uniform samples in the range [0,1].
|
||||||
|
proj_mat (torch.Tensor): _description_
|
||||||
|
i_mat (torch.Tensor): _description_
|
||||||
|
variance (float): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
# since erfinv returns inf when value is -1 or +1, we scale the input to not have
|
||||||
|
# these values.
|
||||||
|
changed_samples = 1.99 * uniform_samples - 0.99
|
||||||
|
gaussian_halton_samples = proj_mat * torch.erfinv(changed_samples)
|
||||||
i_mat = i_mat * variance
|
i_mat = i_mat * variance
|
||||||
gaussian_halton_samples = torch.matmul(gaussian_halton_samples, i_mat)
|
gaussian_halton_samples = torch.matmul(gaussian_halton_samples, i_mat)
|
||||||
return gaussian_halton_samples
|
return gaussian_halton_samples
|
||||||
|
|||||||
Reference in New Issue
Block a user