gfn.utils.distributions
Module Contents
Classes
UnsqueezedCategorical | Samples from a categorical distribution with an unsqueezed final dimension.
- class gfn.utils.distributions.UnsqueezedCategorical
Bases: torch.distributions.Categorical
Samples from a categorical distribution with an unsqueezed final dimension.
Samples are unsqueezed to be of shape (batch_size, 1) instead of (batch_size,).
This is used in DiscretePFEstimator and DiscretePBEstimator, which in turn are used in Sampler.
This helper class makes it easier to represent actions in discrete environments. Discrete environments implemented with the DiscreteActions class (see gfn/env.py::DiscreteEnv) use `action_shape = (1,)`. Therefore, according to gfn/actions.py::Actions, tensors representing actions in discrete environments should be of shape (batch_shape, 1).
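The reshaping described above can be sketched as a thin subclass of torch.distributions.Categorical. This is a minimal sketch under the assumption that the class only adds and removes the trailing axis; the name UnsqueezedCategoricalSketch is hypothetical, and the actual gfn implementation may differ in its details.

```python
import torch
from torch.distributions import Categorical


class UnsqueezedCategoricalSketch(Categorical):
    """Hypothetical sketch: a Categorical whose samples carry a trailing axis of size 1."""

    def sample(self, sample_shape=torch.Size()):
        # Categorical.sample returns shape sample_shape + batch_shape;
        # append a final axis of size 1 so actions match action_shape (1,).
        return super().sample(sample_shape).unsqueeze(-1)

    def log_prob(self, sample):
        # Drop the trailing axis before delegating to Categorical.log_prob,
        # which expects integer indices of shape sample_shape + batch_shape.
        return super().log_prob(sample.squeeze(-1))


# Batch of 4 states, each with 3 possible discrete actions (uniform logits).
logits = torch.zeros(4, 3)
dist = UnsqueezedCategoricalSketch(logits=logits)
actions = dist.sample()
print(actions.shape)  # torch.Size([4, 1])
print(dist.log_prob(actions).shape)  # torch.Size([4])
```

Overriding only sample and log_prob keeps every other Categorical feature (entropy, probs, logits) unchanged, which fits the description of the class as a helper that only changes the shape of samples.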
- log_prob(sample)
Returns the log probabilities of an unsqueezed sample.
- Parameters
sample (torchtyping.TensorType[sample_shape, 1])
- Return type
torchtyping.TensorType[sample_shape]
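The log_prob contract (an input of shape (sample_shape, 1) mapped to one log-probability per entry) can be checked with a plain torch.distributions.Categorical. This sketch assumes the unsqueezed method amounts to squeezing the trailing axis before calling Categorical.log_prob, and cross-checks that against a manual lookup in the log-softmax table. The logits and actions are made-up illustration values.

```python
import torch
from torch.distributions import Categorical

# Unnormalized scores for 2 states over 3 discrete actions.
logits = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.5, 2.0]])
dist = Categorical(logits=logits)

# Actions in the unsqueezed layout: shape (2, 1).
actions = torch.tensor([[2], [0]])

# Squeezing the trailing axis gives Categorical the (2,) indices it expects.
log_p = dist.log_prob(actions.squeeze(-1))

# The same values, read directly from the log-softmax table.
manual = torch.log_softmax(logits, dim=-1).gather(-1, actions).squeeze(-1)
print(log_p.shape)  # torch.Size([2])
```

Passing the (2, 1) tensor straight to a plain Categorical.log_prob would instead broadcast it against the batch and return a (2, 2) tensor, so the trailing axis has to be removed first.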
- sample(sample_shape=torch.Size())
Sample actions with an unsqueezed final dimension.
- Return type
torchtyping.TensorType[sample_shape, 1]
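The interaction between sample_shape and the unsqueezed final dimension can be reproduced with a plain Categorical followed by unsqueeze(-1). This sketch assumes UnsqueezedCategorical.sample reduces to that composition. torch prepends sample_shape to the batch shape, and the extra axis is appended at the very end. The logits are random placeholders.

```python
import torch
from torch.distributions import Categorical

# 4 states, 5 discrete actions each: batch_shape == (4,).
dist = Categorical(logits=torch.randn(4, 5))

# Draw 10 independent samples per state; Categorical returns shape (10, 4).
flat = dist.sample(torch.Size([10]))

# Appending a trailing axis yields the unsqueezed layout (10, 4, 1).
unsqueezed = flat.unsqueeze(-1)
print(flat.shape, unsqueezed.shape)  # torch.Size([10, 4]) torch.Size([10, 4, 1])
```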