gfn.utils.distributions

Module Contents

Classes

UnsqueezedCategorical

Samples from a categorical distribution with an unsqueezed final dimension.

class gfn.utils.distributions.UnsqueezedCategorical

Bases: torch.distributions.Categorical

Samples from a categorical distribution with an unsqueezed final dimension.

Samples are unsqueezed to be of shape (batch_size, 1) instead of (batch_size,).

This is used in DiscretePFEstimator and DiscretePBEstimator, which in turn are used in Sampler.

This helper class makes it easier to represent actions in discrete environments. When implemented with the DiscreteActions class (see gfn/env.py::DiscreteEnv), such actions use `action_shape = (1,)`. Therefore, according to gfn/actions.py::Actions, tensors representing actions in discrete environments should be of shape (batch_shape, 1).
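A minimal sketch of the behaviour described above, assuming the class does nothing more than add a trailing dimension in sample and remove it again in log_prob. UnsqueezedCategoricalSketch is a hypothetical name used for illustration; it is not the gfn implementation.

```python
import torch
from torch.distributions import Categorical


class UnsqueezedCategoricalSketch(Categorical):
    """Hypothetical re-implementation: a Categorical whose samples carry a trailing size-1 dim."""

    def sample(self, sample_shape=torch.Size()):
        # Categorical.sample returns shape sample_shape + batch_shape; append the action dim.
        return super().sample(sample_shape).unsqueeze(-1)

    def log_prob(self, sample):
        # Drop the trailing action dim before delegating to Categorical.log_prob.
        return super().log_prob(sample.squeeze(-1))


logits = torch.randn(4, 3)  # batch of 4 states, 3 discrete actions each
dist = UnsqueezedCategoricalSketch(logits=logits)
actions = dist.sample()
print(actions.shape)  # torch.Size([4, 1])
print(dist.log_prob(actions).shape)  # torch.Size([4])
```

Subclassing Categorical keeps every other method (entropy, probs, logits) unchanged, so only the two shape conversions differ from the base class.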

log_prob(sample)

Returns the log probabilities of an unsqueezed sample.

Parameters

sample (torchtyping.TensorType[sample_shape, 1]) – the sample to score, with an unsqueezed final dimension of size 1.

Return type

torchtyping.TensorType[sample_shape]
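The shape contract of log_prob can be illustrated with plain torch.distributions.Categorical, assuming (as the description suggests) that log_prob is equivalent to squeezing the trailing dimension and scoring the result. The snippet also cross-checks the values against a direct lookup in the normalized log-probabilities.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(4, 3)  # batch_shape = (4,), 3 actions per state
base = Categorical(logits=logits)

# An unsqueezed sample of shape sample_shape + batch_shape + (1,) = (5, 4, 1).
unsqueezed = base.sample(torch.Size([5])).unsqueeze(-1)

# Assumed behaviour of UnsqueezedCategorical.log_prob: drop the trailing dim, then score.
log_p = base.log_prob(unsqueezed.squeeze(-1))
print(unsqueezed.shape, log_p.shape)  # torch.Size([5, 4, 1]) torch.Size([5, 4])

# Cross-check by gathering the normalized log-probabilities directly.
log_probs = torch.log_softmax(logits, dim=-1)  # shape (4, 3)
manual = log_probs.expand(5, 4, 3).gather(-1, unsqueezed).squeeze(-1)
print(torch.allclose(log_p, manual))  # True
```

The returned tensor has the trailing action dimension removed, which matches the documented return type TensorType[sample_shape].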

sample(sample_shape=torch.Size())

Sample actions with an unsqueezed final dimension.

Return type

torchtyping.TensorType[UnsqueezedCategorical.sample.sample_shape, 1]
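The output shapes of sample can be reproduced with plain torch.distributions.Categorical followed by unsqueeze(-1), assuming that is all the method adds. The last lines show one practical consequence of the (batch, 1) layout: the result can be passed directly as an index to torch.gather.

```python
import torch
from torch.distributions import Categorical

logits = torch.randn(4, 3)  # batch_shape = (4,), 3 actions per state
base = Categorical(logits=logits)

# Assumed equivalent of sampling with the default sample_shape:
# Categorical.sample() has shape batch_shape, then the trailing dim is appended.
actions = base.sample().unsqueeze(-1)
print(actions.shape)  # torch.Size([4, 1])

# With a non-empty sample_shape the size-1 action dim is still last.
many = base.sample(torch.Size([2])).unsqueeze(-1)
print(many.shape)  # torch.Size([2, 4, 1])

# The (batch, 1) layout matches gather's requirement that index and input share ndim.
chosen_logits = logits.gather(-1, actions)  # logit of the sampled action per state
print(chosen_logits.shape)  # torch.Size([4, 1])
```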