gfn.modules
Module Contents
Classes
DiscretePolicyEstimator | Container for forward and backward policy estimators for discrete environments.
GFNModule | Base class for modules mapping states to distributions.
ScalarEstimator | Base class for modules mapping states to distributions.
- class gfn.modules.DiscretePolicyEstimator(module, n_actions, preprocessor, is_backward=False)
Bases: GFNModule
Container for forward and backward policy estimators for discrete environments.
\(s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}\).
or
\(s \mapsto (P_B(s' \mid s))_{s' \in Parents(s)}\).
- Parameters
module (torch.nn.Module) –
n_actions (int) –
preprocessor (Preprocessor | None) –
is_backward (bool) –
- temperature
scalar to divide the logits by before softmax.
- sf_bias
scalar to subtract from the exit action logit before dividing by temperature.
- epsilon
with probability epsilon, a random action is chosen.
- expected_output_dim()
Expected output dimension of the module.
- Return type
int
- to_probability_distribution(states, module_output, temperature=1.0, sf_bias=0.0, epsilon=0.0)
Returns a probability distribution given a batch of states and module output.
- Parameters
temperature (float) – scalar to divide the logits by before softmax. Does nothing if set to 1.0 (default), in which case it’s on policy.
sf_bias (float) – scalar to subtract from the exit action logit before dividing by temperature. Does nothing if set to 0.0 (default), in which case it’s on policy.
epsilon (float) – with probability epsilon, a random action is chosen. Does nothing if set to 0.0 (default), in which case it’s on policy.
states (gfn.states.DiscreteStates) –
module_output (torchtyping.TensorType[batch_shape, output_dim, float]) –
- Return type
torch.distributions.Categorical
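As an illustration of how the three exploration parameters could combine, here is a minimal sketch in plain PyTorch. It is not the library's implementation: the helper name `logits_to_categorical` is hypothetical, the exit action is assumed to be the last logit, epsilon is assumed to mix the policy with a uniform distribution over all actions, and action masking is ignored.

```python
import torch
from torch.distributions import Categorical


def logits_to_categorical(logits, temperature=1.0, sf_bias=0.0, epsilon=0.0):
    """Hypothetical helper: turn a batch of logits into a Categorical."""
    logits = logits.clone()
    # Assumption: the exit action is the last entry of each logit vector.
    logits[..., -1] -= sf_bias
    # Divide by the temperature before the softmax.
    probs = torch.softmax(logits / temperature, dim=-1)
    # Assumption: with probability epsilon, act uniformly over all actions.
    uniform = torch.full_like(probs, 1.0 / probs.shape[-1])
    probs = (1.0 - epsilon) * probs + epsilon * uniform
    return Categorical(probs=probs)


logits = torch.tensor([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
on_policy = logits_to_categorical(logits)
biased = logits_to_categorical(logits, sf_bias=1.0)
explore = logits_to_categorical(logits, temperature=2.0, sf_bias=1.0, epsilon=0.1)
actions = explore.sample()  # one action index per batch element
```

With the defaults (temperature=1.0, sf_bias=0.0, epsilon=0.0) the sketch reduces to a plain softmax over the logits, which corresponds to the on-policy case described above.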
- class gfn.modules.GFNModule(module, preprocessor=None, is_backward=False)
Bases: abc.ABC, torch.nn.Module
Base class for modules mapping states to distributions.
Training a GFlowNet requires parameterizing one or more of the following functions:
- \(s \mapsto (\log F(s \rightarrow s'))_{s' \in Children(s)}\)
- \(s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}\)
- \(s' \mapsto (P_B(s \mid s'))_{s \in Parents(s')}\)
- \(s \mapsto (\log F(s))_{s \in States}\)
This class is the base class for all such function estimators. Each estimator encapsulates an nn.Module, which takes a batch of preprocessed states as input and outputs a batch of outputs of the desired shape. When the goal is to represent a probability distribution, the outputs correspond to the parameters of the distribution, e.g. logits for a categorical distribution in discrete environments.
The __call__ method outputs logits or, more generally, the parameters of a distribution. Alternatively, one can override the to_probability_distribution() method to output a probability distribution directly.
The preprocessor is also encapsulated in the estimator. These function estimators implement the __call__ method, which takes States objects as inputs and calls the module on the preprocessed states.
- Parameters
module (torch.nn.Module) –
preprocessor (Preprocessor | None) –
is_backward (bool) –
- preprocessor
Preprocessor object that transforms raw States objects to tensors that can be used as input to the module. Optional, defaults to IdentityPreprocessor.
- module
The module to use. If the module is a Tabular module (from gfn.utils.modules), then the environment preprocessor needs to be an EnumPreprocessor.
- preprocessor
Preprocessor from the environment.
- _output_dim_is_checked
Flag for tracking whether the output dimensions of the states (after being preprocessed and transformed by the modules) have been verified.
- abstract property expected_output_dim: int
Expected output dimension of the module.
- Return type
int
- __repr__()
Return repr(self).
- check_output_dim(module_output)
Check that the output of the module has the correct shape. Raises an error if not.
- Parameters
module_output (torchtyping.TensorType[batch_shape, output_dim, float]) –
- Return type
None
- forward(states)
- Parameters
states (gfn.states.States) –
- Return type
torchtyping.TensorType[batch_shape, output_dim, float]
- abstract to_probability_distribution(states, module_output, *args)
Transform the output of the module into a probability distribution.
The extra arguments modify a base distribution, for example to encourage exploration.
Not all modules must implement this method, but it is required to define a policy from a module’s outputs. See DiscretePolicyEstimator for an example using a categorical distribution, but note this can be done for all continuous distributions as well.
- Parameters
states (gfn.states.States) –
module_output (torchtyping.TensorType[batch_shape, output_dim, float]) –
- Return type
torch.distributions.Distribution
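The wrapping pattern described for this class (preprocess the states, call the module, verify the output dimension once) can be sketched in plain PyTorch as follows. This is a hypothetical stand-in, not the code of gfn.modules.GFNModule: the class name `SketchEstimator` is invented for illustration, states are assumed to be raw tensors rather than States objects, and the expected output dimension is passed in as a constructor argument.

```python
import torch
from torch import nn


class SketchEstimator(nn.Module):
    """Hypothetical stand-in illustrating the estimator pattern."""

    def __init__(self, module, preprocessor=None, expected_output_dim=1):
        super().__init__()
        self.module = module
        # Assumption: fall back to identity preprocessing when none is given.
        self.preprocessor = preprocessor if preprocessor is not None else (lambda x: x)
        self.expected_output_dim = expected_output_dim
        self._output_dim_is_checked = False

    def check_output_dim(self, module_output):
        # Raise if the last dimension does not match the expected one.
        if module_output.shape[-1] != self.expected_output_dim:
            raise ValueError(
                f"Expected output dim {self.expected_output_dim}, "
                f"got {module_output.shape[-1]}"
            )

    def forward(self, states):
        out = self.module(self.preprocessor(states))
        # Verify the shape only on the first call, then remember the result.
        if not self._output_dim_is_checked:
            self.check_output_dim(out)
            self._output_dim_is_checked = True
        return out


estimator = SketchEstimator(nn.Linear(4, 3), expected_output_dim=3)
out = estimator(torch.zeros(5, 4))  # batch of 5 states with 4 features each
```

Checking the shape only once keeps the per-call overhead low while still catching a mismatched module on first use.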
- class gfn.modules.ScalarEstimator(module, preprocessor=None, is_backward=False)
Bases: GFNModule
Base class for modules mapping states to distributions.
Training a GFlowNet requires parameterizing one or more of the following functions:
- \(s \mapsto (\log F(s \rightarrow s'))_{s' \in Children(s)}\)
- \(s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}\)
- \(s' \mapsto (P_B(s \mid s'))_{s \in Parents(s')}\)
- \(s \mapsto (\log F(s))_{s \in States}\)
This class is the base class for all such function estimators. Each estimator encapsulates an nn.Module, which takes a batch of preprocessed states as input and outputs a batch of outputs of the desired shape. When the goal is to represent a probability distribution, the outputs correspond to the parameters of the distribution, e.g. logits for a categorical distribution in discrete environments.
The __call__ method outputs logits or, more generally, the parameters of a distribution. Alternatively, one can override the to_probability_distribution() method to output a probability distribution directly.
The preprocessor is also encapsulated in the estimator. These function estimators implement the __call__ method, which takes States objects as inputs and calls the module on the preprocessed states.
- Parameters
module (torch.nn.Module) –
preprocessor (Preprocessor | None) –
is_backward (bool) –
- preprocessor
Preprocessor object that transforms raw States objects to tensors that can be used as input to the module. Optional, defaults to IdentityPreprocessor.
- module
The module to use. If the module is a Tabular module (from gfn.utils.modules), then the environment preprocessor needs to be an EnumPreprocessor.
- preprocessor
Preprocessor from the environment.
- _output_dim_is_checked
Flag for tracking whether the output dimensions of the states (after being preprocessed and transformed by the modules) have been verified.
- expected_output_dim()
Expected output dimension of the module.
- Return type
int