gfn.utils
Submodules
Package Contents
Classes
NeuralNet | Implements a basic MLP.
Tabular | Implements a tabular policy.
- class gfn.utils.NeuralNet(input_dim, output_dim, hidden_dim=256, n_hidden_layers=2, activation_fn='relu', torso=None)
Bases: torch.nn.Module
Implements a basic MLP.
- Parameters
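input_dim (int) – dimension of each preprocessed input state.
output_dim (int) – dimension of the network output.
hidden_dim (Optional[int]) – number of units in each hidden layer (default 256).
n_hidden_layers (Optional[int]) – number of hidden layers (default 2).
activation_fn (Optional[Literal[relu, tanh, elu]]) – activation applied after each hidden layer (default 'relu').
torso (Optional[torch.nn.Module]) – optional pre-built module used as the body of the network in place of the hidden layers built from the arguments above.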
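The parameters above suggest a conventional MLP layout. The sketch below uses a hypothetical `build_mlp` helper (not part of gfn.utils) and assumes `n_hidden_layers` Linear + activation blocks followed by a single linear output layer, with `torso`, if given, replacing those blocks and producing `hidden_dim` features. The actual NeuralNet internals may differ.

```python
import torch
from torch import nn


def build_mlp(input_dim, output_dim, hidden_dim=256, n_hidden_layers=2,
              activation_fn="relu", torso=None):
    """Hypothetical stand-in for NeuralNet's layer construction.

    Assumption: the torso is n_hidden_layers Linear+activation blocks, and a
    custom `torso` must output `hidden_dim` features.
    """
    activations = {"relu": nn.ReLU, "tanh": nn.Tanh, "elu": nn.ELU}
    if torso is None:
        assert n_hidden_layers >= 1, "this sketch needs at least one hidden layer"
        act = activations[activation_fn]
        # First hidden layer maps input_dim -> hidden_dim.
        layers = [nn.Linear(input_dim, hidden_dim), act()]
        # Remaining hidden layers keep the width at hidden_dim.
        for _ in range(n_hidden_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), act()]
        torso = nn.Sequential(*layers)
    # A final linear layer maps the torso features to output_dim.
    return nn.Sequential(torso, nn.Linear(hidden_dim, output_dim))


net = build_mlp(input_dim=4, output_dim=3, hidden_dim=16, n_hidden_layers=2)
n_linear = sum(isinstance(m, nn.Linear) for m in net.modules())
print(n_linear)  # 3: two hidden layers plus the output layer
```

With a custom torso, only the output layer is added on top, e.g. `build_mlp(4, 3, hidden_dim=16, torso=nn.Linear(4, 16))`.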
- forward(preprocessed_states)
Forward method for the neural network.
- Parameters
preprocessed_states (torchtyping.TensorType[batch_shape, input_dim, float]) – a batch of states appropriately preprocessed for ingestion by the MLP.
- Returns
out, a batch of continuous values of shape [batch_shape, output_dim].
- Return type
torchtyping.TensorType[batch_shape, output_dim, float]
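Because `batch_shape` may span several leading dimensions, it can help to see how a PyTorch MLP handles them. This illustration uses a plain `nn.Sequential`, not NeuralNet itself, and relies only on `nn.Linear` operating on the last dimension of its input.

```python
import torch
from torch import nn

# A tiny MLP for illustration only (not gfn.utils.NeuralNet).
mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# nn.Linear acts on the last dimension, so any leading batch_shape is kept.
states = torch.randn(10, 7, 4)  # batch_shape = (10, 7), input_dim = 4
out = mlp(states)
print(out.shape)  # torch.Size([10, 7, 3])
```

In this illustration the last layer is linear, so the outputs are unbounded real values; any normalization (for example a softmax) would be up to the caller.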
- class gfn.utils.Tabular(n_states, output_dim)
Bases: torch.nn.Module
Implements a tabular policy.
This class is only compatible with the EnumPreprocessor.
- Parameters
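n_states (int) – number of states; each state has one row in table.
output_dim (int) – number of outputs per state (the columns of table).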
- table
a tensor with dimensions [n_states, output_dim].
- device
the device that holds this policy.
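A tabular policy can be viewed as a learnable lookup table. The hypothetical `TabularSketch` class below (not the library's implementation) stores `table` as an `nn.Parameter` of shape [n_states, output_dim] so an optimizer can update it, and derives `device` from that parameter. The zero initialization is an assumption made for the sketch.

```python
import torch
from torch import nn


class TabularSketch(nn.Module):
    """Hypothetical minimal tabular policy: one learnable row per state."""

    def __init__(self, n_states: int, output_dim: int):
        super().__init__()
        # Zero initialization is an assumption of this sketch.
        self.table = nn.Parameter(torch.zeros(n_states, output_dim))

    @property
    def device(self) -> torch.device:
        # The table is the only parameter, so its device is the module's device.
        return self.table.device


policy = TabularSketch(n_states=6, output_dim=3)
print(policy.table.shape)  # torch.Size([6, 3])
print(sum(p.numel() for p in policy.parameters()))  # 18
```

The parameter count grows as n_states * output_dim, which is why a table is only practical for small, enumerable state spaces.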
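- forward(preprocessed_states)
Forward method for the tabular policy.
- Parameters
preprocessed_states (torchtyping.TensorType[batch_shape, input_dim, float]) – a batch of states preprocessed by the EnumPreprocessor, which identifies each state by its index (its row in table).
- Returns
the rows of table corresponding to the given states.
- Return type
torchtyping.TensorType[batch_shape, output_dim, float]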
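Given the EnumPreprocessor requirement, the forward pass amounts to selecting one row of `table` per state. The sketch below assumes, without confirmation from the signature above (which annotates the input as float), that preprocessed states arrive as integer indices of shape [*batch_shape, 1]; it uses a plain tensor in place of the module's table.

```python
import torch

n_states, output_dim = 5, 2
# Stand-in table: row i holds [2*i, 2*i + 1].
table = torch.arange(n_states * output_dim, dtype=torch.float).reshape(n_states, output_dim)

# Assumed EnumPreprocessor-style input: integer state indices with a
# trailing singleton dimension, shape (*batch_shape, 1).
indices = torch.tensor([[[0], [4]], [[2], [2]]])  # batch_shape = (2, 2)

# Drop the trailing dimension and look up one row of the table per state.
outputs = table[indices.squeeze(-1)]
print(outputs.shape)  # torch.Size([2, 2, 2])
print(outputs[0, 1])  # tensor([8., 9.])
```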