synthcity.plugins.core.models.tabular_ddpm.utils module

approx_standard_normal_cdf(x: torch.Tensor) → torch.Tensor

A fast approximation of the cumulative distribution function of the standard normal.
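A scalar sketch of the approximation, using only the Python standard library rather than torch. It assumes the tanh-based formula from the improved-diffusion reference code, which is the same expression used in the tanh approximation of GELU. The exact CDF, computed with `math.erf`, is included only for comparison.

```python
import math


def approx_standard_normal_cdf(x: float) -> float:
    """Tanh approximation of the standard normal CDF (assumed formula)."""
    # 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def exact_standard_normal_cdf(x: float) -> float:
    """Reference value: Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    approx = approx_standard_normal_cdf(x)
    exact = exact_standard_normal_cdf(x)
    print(f"x={x:+.1f}  approx={approx:.6f}  exact={exact:.6f}")
```

The approximation needs one `tanh` and no `erf`, and it stays within a few parts in 10,000 of the exact value.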

discretized_gaussian_log_likelihood(x: torch.Tensor, *, means: torch.Tensor, log_scales: torch.Tensor) → torch.Tensor

Compute the log-likelihood of a Gaussian distribution discretized to the pixel values of a given image.

Parameters
  • x – the target images. It is assumed that these were uint8 values, rescaled to the range [-1, 1].

  • means – the Gaussian mean Tensor.

  • log_scales – the Gaussian log stddev Tensor.

Returns

A tensor with the same shape as x, containing log probabilities (in nats).
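A scalar sketch of the discretization, using only the standard library. The scalar parameters mean and log_scale stand in for the tensor arguments means and log_scales. The sketch assumes the binning used in the improved-diffusion reference code, and these constants are assumptions rather than values stated on this page:

- 256 levels in [-1, 1], each bin with half-width 1/255.
- The outermost levels, detected with the thresholds ±0.999, absorb the Gaussian tails.
- Probabilities are floored at 1e-12 before taking the log.

```python
import math


def approx_standard_normal_cdf(x: float) -> float:
    # Same tanh approximation as the one assumed for approx_standard_normal_cdf.
    return 0.5 * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def discretized_gaussian_log_likelihood(x: float, *, mean: float, log_scale: float) -> float:
    """Log-probability (nats) of the bin centred on x, for x in [-1, 1].

    Assumes 256 uint8 levels mapped to [-1, 1], so each bin has half-width 1/255.
    """
    eps = 1e-12  # floor applied before taking logs, to avoid log(0)
    centered_x = x - mean
    inv_stdv = math.exp(-log_scale)
    cdf_plus = approx_standard_normal_cdf(inv_stdv * (centered_x + 1.0 / 255.0))
    cdf_min = approx_standard_normal_cdf(inv_stdv * (centered_x - 1.0 / 255.0))
    if x < -0.999:
        # Lowest level: integrate the Gaussian from -inf up to the upper bin edge.
        return math.log(max(cdf_plus, eps))
    if x > 0.999:
        # Highest level: integrate from the lower bin edge up to +inf.
        return math.log(max(1.0 - cdf_min, eps))
    # Interior level: probability mass between the two bin edges.
    return math.log(max(cdf_plus - cdf_min, eps))


# The bins tile the real line, so the probabilities of all 256 levels sum to ~1.
levels = [-1.0 + 2.0 * k / 255.0 for k in range(256)]
total = sum(
    math.exp(discretized_gaussian_log_likelihood(v, mean=0.1, log_scale=math.log(0.5)))
    for v in levels
)
print(total)  # ≈ 1.0
```

The sum telescopes. Each bin's upper edge is the next bin's lower edge, and the two end bins reach to ±infinity, so the per-level probabilities form a proper distribution over the 256 values.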

index_to_log_onehot(x: torch.Tensor, num_classes: numpy.ndarray) → torch.Tensor
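This helper has no docstring. The following is a plausible reading, assuming it matches the TabDDPM helper it appears to derive from:

- Each column of x holds a category index.
- num_classes gives the number of categories for each column.
- The result concatenates the per-column one-hot vectors in log space.

Zeros are clamped to a tiny positive value before the log so that the result stays finite. The value 1e-30 below is an assumption. The sketch handles a single row (the hypothetical parameter row) with plain lists instead of tensors.

```python
import math
from typing import List, Sequence


def index_to_log_onehot(row: Sequence[int], num_classes: Sequence[int]) -> List[float]:
    """Map one row of category indices to a concatenated log one-hot vector."""
    floor = 1e-30  # assumed clamp, so log(0) becomes a large negative number
    out: List[float] = []
    for index, k in zip(row, num_classes):
        if not 0 <= index < k:
            raise ValueError(f"index {index} out of range for {k} classes")
        onehot = [1.0 if j == index else 0.0 for j in range(k)]
        out.extend(math.log(max(v, floor)) for v in onehot)
    return out


# Two categorical columns with 3 and 2 categories; the row selects 2 and 0.
print(index_to_log_onehot([2, 0], [3, 2]))
# -> 0.0 at the selected positions, log(1e-30) ≈ -69.08 elsewhere
```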
log_1_min_a(a: torch.Tensor) → torch.Tensor
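This function is also undocumented. Assume, as in TabDDPM, that a is a log-probability and that the function returns log(1 − exp(a)), the log of the complementary probability. Under that assumption, a numerically stable scalar version can use `math.log1p`. The TabDDPM original adds a tiny epsilon inside a plain log instead, so treat this as an equivalent sketch rather than the library code.

```python
import math


def log_1_min_a(a: float) -> float:
    """Return log(1 - exp(a)) for a log-probability a < 0."""
    if a >= 0.0:
        raise ValueError("a must be negative (a log-probability below 1)")
    # log1p(-exp(a)) avoids cancellation when exp(a) is small.
    return math.log1p(-math.exp(a))


print(log_1_min_a(math.log(0.25)))  # ≈ log(0.75)
```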
log_add_exp(a: torch.Tensor, b: torch.Tensor) → torch.Tensor

Numerically stable log(exp(a) + exp(b)).
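The standard way to make this stable is the max trick: factor out the larger exponent so that no term can overflow. Below is a scalar sketch using the standard library. The torch function operates elementwise on tensors, and whether it uses exactly this formulation is an assumption.

```python
import math


def log_add_exp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    if m == -math.inf:
        return -math.inf  # both inputs are log(0)
    # After factoring out exp(m), the largest term is exp(0) = 1, so nothing overflows.
    return m + math.log(math.exp(a - m) + math.exp(b - m))


print(log_add_exp(1000.0, 1000.0))  # 1000 + log(2); math.exp(1000.0) alone would overflow
```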

log_categorical(log_x_start: torch.Tensor, log_prob: torch.Tensor) → torch.Tensor
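This function is undocumented here. Assume it matches the TabDDPM helper. Then for each sample it computes Σ_k exp(log_x_start[k]) · log_prob[k], which is the expected log-probability under the distribution encoded by log_x_start. When log_x_start is one-hot, this picks out the log-probability of the true class. Below is a single-row sketch with plain lists.

```python
import math
from typing import Sequence


def log_categorical(log_x_start: Sequence[float], log_prob: Sequence[float]) -> float:
    """Sum over classes of exp(log_x_start) * log_prob, for one sample."""
    if len(log_x_start) != len(log_prob):
        raise ValueError("inputs must have the same number of classes")
    return sum(math.exp(lx) * lp for lx, lp in zip(log_x_start, log_prob))


# One-hot target for class 1, with log(0) replaced by a very negative number.
target = [-1e4, 0.0, -1e4]
probs = [0.2, 0.5, 0.3]
print(log_categorical(target, [math.log(p) for p in probs]))  # ≈ log(0.5)
```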
mean_flat(tensor: torch.Tensor) → torch.Tensor

Take the mean over all non-batch dimensions.
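In torch this is typically a one-liner such as `tensor.mean(dim=list(range(1, tensor.ndim)))`, though whether synthcity writes it exactly that way is an assumption. The list-based sketch below makes the semantics explicit: every sample in the batch is reduced to a single mean, whatever its inner shape.

```python
from typing import Any, Iterator, List, Sequence


def _flatten(x: Any) -> Iterator[float]:
    """Yield the scalar leaves of an arbitrarily nested list or tuple."""
    if isinstance(x, (list, tuple)):
        for item in x:
            yield from _flatten(item)
    else:
        yield x


def mean_flat(batch: Sequence[Any]) -> List[float]:
    """Mean over every dimension except the first (batch) one."""
    means = []
    for sample in batch:
        values = list(_flatten(sample))
        means.append(sum(values) / len(values))
    return means


# Batch of two 2x2 "images"; the result has one value per sample.
print(mean_flat([[[1, 2], [3, 4]], [[0, 0], [0, 8]]]))  # [2.5, 2.0]
```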

normal_kl(mean1: torch.Tensor, logvar1: torch.Tensor, mean2: torch.Tensor, logvar2: torch.Tensor) → torch.Tensor

Compute the KL divergence between two Gaussians.

Shapes are broadcast automatically, so batches can be compared to scalars, among other use cases.
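Given means and log-variances, the closed form is KL(N(μ1, σ1²) ‖ N(μ2, σ2²)) = ½ (−1 + log σ2² − log σ1² + σ1²/σ2² + (μ1 − μ2)²/σ2²). The scalar sketch below implements that formula with the standard library. The argument names follow the signature above. The torch version applies the same expression elementwise with broadcasting, and its exact arrangement of terms is assumed rather than confirmed.

```python
import math


def normal_kl(mean1: float, logvar1: float, mean2: float, logvar2: float) -> float:
    """KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) in nats."""
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + math.exp(logvar1 - logvar2)  # sigma1^2 / sigma2^2
        + (mean1 - mean2) ** 2 * math.exp(-logvar2)  # squared mean gap / sigma2^2
    )


print(normal_kl(0.0, 0.0, 0.0, 0.0))  # 0.0: identical distributions
print(normal_kl(1.0, 0.0, 0.0, 0.0))  # 0.5: unit-variance Gaussians one unit apart
```

Working with log-variances keeps the inputs unconstrained and avoids taking the log of a variance that might underflow to zero.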

ohe_to_categories(ohe: torch.Tensor, K: numpy.ndarray) → torch.Tensor
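This function is undocumented here. Assume it inverts a concatenated one-hot encoding, with K listing the number of categories of each feature. Each feature's slice of a row is then reduced to the index of its largest entry (an argmax). Below is a sketch over a batch of rows, using plain lists instead of tensors.

```python
from typing import List, Sequence


def ohe_to_categories(ohe: Sequence[Sequence[float]], K: Sequence[int]) -> List[List[int]]:
    """Convert concatenated one-hot rows back to one category index per feature."""
    result = []
    for row in ohe:
        if len(row) != sum(K):
            raise ValueError("row length must equal sum(K)")
        indices = []
        start = 0
        for k in K:
            block = row[start:start + k]
            # argmax within this feature's block of k columns
            indices.append(max(range(k), key=lambda j: block[j]))
            start += k
        result.append(indices)
    return result


# Two features with 3 and 4 categories.
rows = [[0, 1, 0, 0, 0, 1, 0], [1, 0, 0, 0, 0, 0, 1]]
print(ohe_to_categories(rows, [3, 4]))  # [[1, 2], [0, 3]]
```

Using argmax instead of searching for an exact 1 also accepts soft or noisy encodings, such as model outputs, and returns the most likely category for each feature.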
perm_and_expand(a: torch.Tensor, t: torch.Tensor, x_shape: tuple) → torch.Tensor

Permutes a tensor in the order specified by t and expands it to x_shape.
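Below is a list-based sketch. It assumes the function behaves like the `extract` helper found in many diffusion codebases:

- a is a 1-D per-timestep schedule.
- t holds one timestep index per batch element.
- Each selected value a[t[i]] is broadcast over the remaining dimensions of x_shape.

This behaviour is an assumption, not confirmed by the docstring. Under this reading t may repeat indices, so the reordering is an index lookup rather than a strict permutation. The schedule values in the usage line are hypothetical.

```python
from typing import Any, List, Sequence, Tuple


def _filled(value: float, shape: Tuple[int, ...]) -> Any:
    """Nested list of the given shape with every entry equal to value."""
    if not shape:
        return value
    return [_filled(value, shape[1:]) for _ in range(shape[0])]


def perm_and_expand(a: Sequence[float], t: Sequence[int], x_shape: Tuple[int, ...]) -> List[Any]:
    """Look up a[t[i]] for each batch element i and broadcast it to x_shape[1:]."""
    if len(t) != x_shape[0]:
        raise ValueError("t must have one index per batch element")
    return [_filled(a[step], tuple(x_shape[1:])) for step in t]


schedule = [0.1, 0.2, 0.3]  # hypothetical per-timestep coefficients
print(perm_and_expand(schedule, [2, 0], (2, 3)))
# [[0.3, 0.3, 0.3], [0.1, 0.1, 0.1]]
```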

sum_except_batch(x: torch.Tensor, num_dims: int = 1) → torch.Tensor

Sums over all dimensions except the leading batch dimension(s). With the default num_dims=1, this is every dimension except the first.

Parameters
  • x – Tensor, shape (batch_size, …)

  • num_dims – int, number of leading batch dimensions to keep (default=1)

Returns

x_sum – Tensor, shape (batch_size,) for the default num_dims=1

Return type

torch.Tensor
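The following list-based sketch shows the behaviour described above: the leading num_dims dimensions are kept, and everything after them is summed. In torch this is commonly written as `x.reshape(*x.shape[:num_dims], -1).sum(-1)`, which is what the TabDDPM helper does. That synthcity uses the same one-liner is an assumption.

```python
from typing import Any, Iterator, List


def _flatten(x: Any) -> Iterator[float]:
    """Yield the scalar leaves of an arbitrarily nested list or tuple."""
    if isinstance(x, (list, tuple)):
        for item in x:
            yield from _flatten(item)
    else:
        yield x


def sum_except_batch(x: List[Any], num_dims: int = 1) -> List[Any]:
    """Keep the first num_dims dimensions and sum over all remaining ones."""
    if num_dims < 1:
        raise ValueError("num_dims must be at least 1")
    if num_dims == 1:
        return [sum(_flatten(sample)) for sample in x]
    # Recurse through the extra batch dimensions before reducing.
    return [sum_except_batch(sub, num_dims - 1) for sub in x]


x = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # shape (2, 2, 2)
print(sum_except_batch(x))              # [10, 26]: shape (2,)
print(sum_except_batch(x, num_dims=2))  # [[3, 7], [11, 15]]: shape (2, 2)
```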