gfn.utils.modules

This module contains example torch.nn.Module subclasses that can be used as function approximators with GFN.

Module Contents

Classes

DiscreteUniform

Implements a uniform distribution over discrete actions.

NeuralNet

Implements a basic MLP.

Tabular

Implements a tabular policy.

class gfn.utils.modules.DiscreteUniform(output_dim)

Bases: torch.nn.Module

Implements a uniform distribution over discrete actions.

It uses a zero function approximator (a function that always outputs 0) whose outputs serve as logits for a DiscretePBEstimator. Because all logits are equal, applying a softmax to them yields a uniform distribution.

Parameters

output_dim (int) – The size of the output space.

output_dim

The size of the output space.

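forward(preprocessed_states)

Returns a tensor of zeros with the same batch shape as the input. The values of preprocessed_states are ignored.

Parameters

preprocessed_states (torchtyping.TensorType[batch_shape, input_dim, float]) – a batch of preprocessed states.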

Return type

torchtyping.TensorType[batch_shape, output_dim, float]
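The zero-logit idea can be checked with a minimal sketch, assuming only PyTorch. ZeroLogits is a hypothetical stand-in written for illustration, not the library's DiscreteUniform: it returns zeros shaped [*batch_shape, output_dim], and a softmax over those zeros is uniform. The masking step assumes invalid actions are excluded by setting their logits to -inf, which is one common approach and not something this page documents.

```python
import torch
import torch.nn as nn


class ZeroLogits(nn.Module):
    """Hypothetical stand-in for DiscreteUniform: always outputs zero logits."""

    def __init__(self, output_dim: int):
        super().__init__()
        self.output_dim = output_dim

    def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor:
        # Keep the batch shape (all dims but the last) and replace the last dim
        # with output_dim; the input values themselves are ignored.
        batch_shape = preprocessed_states.shape[:-1]
        return torch.zeros(
            *batch_shape, self.output_dim, device=preprocessed_states.device
        )


module = ZeroLogits(output_dim=4)
states = torch.randn(3, 5)  # batch of 3 states, input_dim = 5
logits = module(states)
probs = torch.softmax(logits, dim=-1)
print(logits.shape)  # torch.Size([3, 4])
print(probs[0])      # tensor([0.2500, 0.2500, 0.2500, 0.2500])

# Assumption: invalid actions are masked by setting their logits to -inf.
mask = torch.tensor([True, True, False, True])  # hypothetical valid-action mask
masked_logits = logits[0].masked_fill(~mask, float("-inf"))
masked_probs = torch.softmax(masked_logits, dim=-1)
print(masked_probs)  # tensor([0.3333, 0.3333, 0.0000, 0.3333])
```

Under that assumption the distribution stays uniform, but only over the actions left unmasked.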

class gfn.utils.modules.NeuralNet(input_dim, output_dim, hidden_dim=256, n_hidden_layers=2, activation_fn='relu', torso=None)

Bases: torch.nn.Module

Implements a basic MLP.

Parameters
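  • input_dim (int) – the dimension of the preprocessed input states.

  • output_dim (int) – the dimension of the output.

  • hidden_dim (Optional[int]) – the number of units in each hidden layer (default 256).

  • n_hidden_layers (Optional[int]) – the number of hidden layers (default 2).

  • activation_fn (Optional[Literal['relu', 'tanh', 'elu']]) – the activation function applied after each hidden layer (default 'relu').

  • torso (Optional[torch.nn.Module]) – if provided, this module is used as the torso of the network (all layers except the last one).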

forward(preprocessed_states)

Forward method for the neural network.

Parameters

preprocessed_states (torchtyping.TensorType[batch_shape, input_dim, float]) – a batch of states appropriately preprocessed for ingestion by the MLP.

Return type

torchtyping.TensorType[batch_shape, output_dim, float]

Returns: out, a tensor of real-valued outputs, one vector of size output_dim per state in the batch.
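A minimal sketch of an MLP with this constructor signature, assuming only PyTorch. SketchMLP and ACTIVATIONS are hypothetical names, and the layer layout (n_hidden_layers Linear + activation pairs followed by a final Linear output layer) plus the requirement that a custom torso outputs hidden_dim features are assumptions about the structure, not the library's code.

```python
from typing import Literal, Optional

import torch
import torch.nn as nn

# Assumed mapping from activation_fn strings to torch.nn activation layers.
ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh, "elu": nn.ELU}


class SketchMLP(nn.Module):
    """Hypothetical re-creation of NeuralNet's interface; not the library code."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 256,
        n_hidden_layers: int = 2,
        activation_fn: Literal["relu", "tanh", "elu"] = "relu",
        torso: Optional[nn.Module] = None,
    ):
        super().__init__()
        if torso is None:
            assert n_hidden_layers >= 1, "this sketch needs at least one hidden layer"
            act = ACTIVATIONS[activation_fn]
            layers = [nn.Linear(input_dim, hidden_dim), act()]
            for _ in range(n_hidden_layers - 1):
                layers += [nn.Linear(hidden_dim, hidden_dim), act()]
            torso = nn.Sequential(*layers)
        # Assumption: a user-supplied torso maps input_dim -> hidden_dim features.
        self.torso = torso
        self.last_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last dimension, so any batch_shape is preserved.
        return self.last_layer(self.torso(preprocessed_states))


net = SketchMLP(input_dim=8, output_dim=5, hidden_dim=32, n_hidden_layers=2)
x = torch.randn(4, 6, 8)  # batch_shape = (4, 6), input_dim = 8
out = net(x)
print(out.shape)  # torch.Size([4, 6, 5])

# A second head that reuses the same torso (shared hidden layers).
head2 = SketchMLP(input_dim=8, output_dim=3, hidden_dim=32, torso=net.torso)
```

Passing net.torso to a second module makes both share the same hidden-layer parameters while keeping separate output layers, which is one plausible reason to expose a torso argument.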

class gfn.utils.modules.Tabular(n_states, output_dim)

Bases: torch.nn.Module

Implements a tabular policy.

This class is only compatible with the EnumPreprocessor.

Parameters
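  • n_states (int) – the number of states in the environment; each state gets its own row of table.

  • output_dim (int) – the size of the output space.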

table

a tensor of shape [n_states, output_dim], holding one row of outputs per state.

device

the device that holds this policy.

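forward(preprocessed_states)

Returns the rows of table that correspond to the input states.

Parameters

preprocessed_states (torchtyping.TensorType[batch_shape, input_dim, float]) – a batch of state indices, as produced by the EnumPreprocessor.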

Return type

torchtyping.TensorType[batch_shape, output_dim, float]
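A minimal sketch of a tabular module, assuming only PyTorch. SketchTabular is a hypothetical name; it assumes the preprocessed states are integer state indices of shape [*batch_shape, 1], as an enumerating preprocessor would plausibly produce, and simply looks up the matching rows of table.

```python
import torch
import torch.nn as nn


class SketchTabular(nn.Module):
    """Hypothetical tabular module: one learnable row of outputs per state."""

    def __init__(self, n_states: int, output_dim: int):
        super().__init__()
        # table has shape [n_states, output_dim]; every entry is trainable.
        self.table = nn.Parameter(torch.zeros(n_states, output_dim))

    def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor:
        # Assumption: inputs are state indices of shape [*batch_shape, 1].
        indices = preprocessed_states.long().squeeze(-1)
        return self.table[indices]  # shape [*batch_shape, output_dim]


tab = SketchTabular(n_states=10, output_dim=3)
with torch.no_grad():
    tab.table[7] = torch.tensor([1.0, 2.0, 3.0])

idx = torch.tensor([[7], [0]])  # two states: index 7 and index 0
out = tab(idx)  # out[0] is row 7 of the table, out[1] is row 0
```

Storing table as an nn.Parameter makes every entry independently trainable by gradient descent; unlike an MLP, nothing is shared between states.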