gfn.utils

Submodules

Package Contents

Classes

NeuralNet – Implements a basic MLP.

Tabular – Implements a tabular policy.

class gfn.utils.NeuralNet(input_dim, output_dim, hidden_dim=256, n_hidden_layers=2, activation_fn='relu', torso=None)

Bases: torch.nn.Module

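Implements a basic multilayer perceptron (MLP) that maps preprocessed states of dimension input_dim to output_dim real-valued outputs.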

Parameters
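  • input_dim (int) – dimension of each preprocessed input state.

  • output_dim (int) – dimension of the network output.

  • hidden_dim (Optional[int]) – number of units in each hidden layer (default 256).

  • n_hidden_layers (Optional[int]) – number of hidden layers (default 2).

  • activation_fn (Optional[Literal['relu', 'tanh', 'elu']]) – nonlinearity used in the hidden layers (default 'relu').

  • torso (Optional[torch.nn.Module]) – optional module to use as the torso of the network, i.e. all layers except the final output layer. If None, the torso is built from hidden_dim, n_hidden_layers and activation_fn.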
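As a rough guide to model size, the sketch below counts trainable parameters under an assumed layout. The layout has a first layer from input_dim to hidden_dim, then n_hidden_layers - 1 further hidden-to-hidden layers, then a final linear layer to output_dim. The layout and the helper names linear_params and mlp_param_count are assumptions for illustration only. They are not part of gfn.utils, and a custom torso would change the count.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer (like torch.nn.Linear with bias=True)."""
    return n_in * n_out + n_out


def mlp_param_count(input_dim: int, output_dim: int,
                    hidden_dim: int = 256, n_hidden_layers: int = 2) -> int:
    """Trainable parameters of an MLP with `n_hidden_layers` hidden layers of width
    `hidden_dim` and a final linear layer to `output_dim`.

    Assumed layout (not confirmed by the gfn.utils docs):
        input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim
    Activation functions have no parameters, so they add nothing to the count.
    """
    if n_hidden_layers < 1:
        raise ValueError("this sketch assumes at least one hidden layer")
    dims = [input_dim] + [hidden_dim] * n_hidden_layers + [output_dim]
    # One linear layer between each pair of consecutive widths.
    return sum(linear_params(a, b) for a, b in zip(dims[:-1], dims[1:]))


# Default hyperparameters: 4 -> 256 -> 256 -> 3
print(mlp_param_count(4, 3))  # → 67843
```

With the defaults, the hidden-to-hidden layer (256 * 256 + 256 = 65792 parameters) dominates. For small inputs, hidden_dim is therefore the main lever on model size.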

forward(preprocessed_states)

Forward method for the neural network.

Parameters

preprocessed_states (torchtyping.TensorType[batch_shape, input_dim, float]) – a batch of states, preprocessed so that each state is a vector of length input_dim.

Return type

torchtyping.TensorType[batch_shape, output_dim, float]

Returns: out, a tensor of real-valued outputs of shape [batch_shape, output_dim].
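The forward pass of an MLP like this one can be sketched without PyTorch. The pure-Python version below makes several assumptions: zero biases, uniform random weights, ReLU after every hidden layer, and no activation after the final layer. The helpers init_mlp and forward are hypothetical and are not the gfn.utils API. The sketch loops over the batch explicitly. By contrast, torch.nn.Linear applies the same affine map to the last dimension of a tensor with any leading batch_shape.

```python
import random
from typing import List, Sequence, Tuple

Layer = Tuple[List[List[float]], List[float]]  # (weights[out][in], biases[out])


def init_mlp(input_dim: int, output_dim: int, hidden_dim: int = 256,
             n_hidden_layers: int = 2, seed: int = 0) -> List[Layer]:
    """Random weights and zero biases for input -> hidden x n_hidden_layers -> output."""
    rng = random.Random(seed)
    dims = [input_dim] + [hidden_dim] * n_hidden_layers + [output_dim]
    layers: List[Layer] = []
    for n_in, n_out in zip(dims[:-1], dims[1:]):
        weights = [[rng.uniform(-1.0, 1.0) for _ in range(n_in)] for _ in range(n_out)]
        layers.append((weights, [0.0] * n_out))
    return layers


def forward(layers: List[Layer], state: Sequence[float]) -> List[float]:
    """Map one preprocessed state (length input_dim) to output_dim real values."""
    h = list(state)
    for i, (weights, biases) in enumerate(layers):
        # Affine map W h + b, as torch.nn.Linear computes for a single vector.
        h = [sum(w * x for w, x in zip(row, h)) + b for row, b in zip(weights, biases)]
        if i < len(layers) - 1:  # ReLU on hidden layers only; the last layer stays linear
            h = [max(0.0, v) for v in h]
    return h


layers = init_mlp(input_dim=4, output_dim=3, hidden_dim=16)
batch = [[0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0]]  # batch_shape = (2,)
out = [forward(layers, s) for s in batch]             # shape (2, 3)
print(len(out), len(out[0]))                          # → 2 3
```

In this sketch the final layer has no activation. Each state therefore yields output_dim unconstrained real numbers, which downstream code is free to interpret, for example, as logits.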

class gfn.utils.Tabular(n_states, output_dim)

Bases: torch.nn.Module

Implements a tabular policy: a learnable table that stores one row of output_dim values per state.

This class is only compatible with the EnumPreprocessor.

Parameters
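  • n_states (int) – number of states in the environment, i.e. the number of rows of table.

  • output_dim (int) – output dimension, i.e. the number of columns of table.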

table

A tensor of shape [n_states, output_dim] containing one row of outputs per state.

device

The device that holds this policy.

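forward(preprocessed_states)

Forward method for the tabular policy: looks up the row of table for each state in the batch.

Parameters

preprocessed_states (torchtyping.TensorType[batch_shape, input_dim, float]) – a batch of states preprocessed by the EnumPreprocessor, which identifies each state by its index into table.

Return type

torchtyping.TensorType[batch_shape, output_dim, float]

Returns: out, the rows of table corresponding to the input states.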
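A tabular policy replaces the network with a lookup table. The hypothetical TabularSketch below shows the idea using plain Python lists instead of a torch tensor, with zero initialisation assumed. Each state is first mapped to an integer index, which is the role an enumeration preprocessor plays. The forward pass then returns the matching rows. This is an illustration, not the gfn.utils implementation.

```python
from typing import List, Sequence


class TabularSketch:
    """Hypothetical pure-Python stand-in for a tabular policy.

    Holds one row of `output_dim` values per state; a forward pass is a row lookup.
    Zero initialisation is an assumption made for this sketch.
    """

    def __init__(self, n_states: int, output_dim: int) -> None:
        self.table: List[List[float]] = [[0.0] * output_dim for _ in range(n_states)]

    def forward(self, state_indices: Sequence[int]) -> List[List[float]]:
        # Each state must already be an integer index in [0, n_states);
        # copies are returned so callers cannot mutate the table by accident.
        return [list(self.table[i]) for i in state_indices]


policy = TabularSketch(n_states=5, output_dim=3)
policy.table[2][0] = 1.5        # an "update" touches only the row for state 2
print(policy.forward([2, 4]))   # → [[1.5, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

Because each state owns its row, changing the outputs for one state never affects another. The price is n_states * output_dim parameters, which is only practical for small, fully enumerable state spaces.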