gfn.gflownet.flow_matching
Module Contents
Classes
FMGFlowNet – Flow Matching GFlowNet, with edge flow estimator.
- class gfn.gflownet.flow_matching.FMGFlowNet(logF, alpha=1.0)
Bases: gfn.gflownet.base.GFlowNet[Tuple[gfn.states.DiscreteStates, gfn.states.DiscreteStates]]
Flow Matching GFlowNet, with edge flow estimator.
\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). This is equivalent to the set of functions from the internal nodes (i.e. all nodes except \(s_f\)) to \(\mathbb{R}^{n_{actions}}\), without the exit action. Positivity is not needed if we parametrize log-flows.
The loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
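For orientation, the conditions this loss targets can be sketched as follows. The notation is introduced here and is not taken from this page: \(F\) is the edge flow, \(R\) the reward, \(\mathrm{Par}(s)\) the parents of \(s\), and \(\mathrm{Child}(s)\) the children of \(s\) (including \(s_f\) when exiting is allowed).

\[
\sum_{s' \in \mathrm{Par}(s)} F(s' \to s) = \sum_{s'' \in \mathrm{Child}(s)} F(s \to s''), \qquad F(s \to s_f) = R(s).
\]

With a log-flow estimator \(\log F_\theta\), the first condition corresponds to driving the following residual toward zero for each state \(s \neq s_0\):

\[
\delta_\theta(s) = \log \sum_{s' \in \mathrm{Par}(s)} \exp\big(\log F_\theta(s' \to s)\big) - \log \sum_{s'' \in \mathrm{Child}(s)} \exp\big(\log F_\theta(s \to s'')\big).
\]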
- Parameters
logF (gfn.modules.DiscretePolicyEstimator) –
alpha (float) –
- logF
LogEdgeFlowEstimator
- alpha
weight for the reward matching loss.
- flow_matching_loss(env, states)
Computes the flow matching (FM) loss for the provided states.
The flow matching loss is defined as the log of the summed incoming flows minus the log of the summed outgoing flows. The states must not include \(s_0\), and the batch shape must be (n_states,). Currently, only discrete environments are handled.
- Raises
AssertionError – If the batch shape is not one-dimensional, i.e. not (n_states,).
AssertionError – If any state is at \(s_0\).
- Parameters
env (gfn.env.Env) –
states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[n_trajectories, torch.float]
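The quantity described for flow_matching_loss can be illustrated with a small sketch. It is written in plain Python rather than PyTorch so that it runs without the library, and the helper names `logsumexp` and `flow_matching_residual` are hypothetical, not part of `gfn`. It computes only the per-state difference of log-summed flows; how the library reduces these values over a batch is not stated on this page.

```python
import math


def logsumexp(log_values):
    """Numerically stable log(sum(exp(v))) for a non-empty list of floats."""
    m = max(log_values)
    if m == -math.inf:
        # All flows are zero, so the log of their sum is -inf.
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in log_values))


def flow_matching_residual(incoming_log_flows, outgoing_log_flows):
    """Log of the summed incoming flows minus log of the summed outgoing flows."""
    return logsumexp(incoming_log_flows) - logsumexp(outgoing_log_flows)


# Toy state (hypothetical numbers): two incoming edges with flows 1 and 3,
# two outgoing edges with flows 2 and 2, so the total flow is conserved.
incoming = [math.log(1.0), math.log(3.0)]
outgoing = [math.log(2.0), math.log(2.0)]
residual = flow_matching_residual(incoming, outgoing)
```

When inflow and outflow balance, as in the toy state above, the residual is zero up to floating-point error. An imbalance in either direction moves it away from zero.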
- loss(env, states_tuple)
Given a batch of non-terminal and terminal states, compute a loss.
Unlike the GFlowNet Foundations paper, this method takes a tuple of states for more flexibility: the first element holds the internal (i.e. non-terminal) states of the trajectories, and the second holds the terminal states of the trajectories.
- Parameters
env (gfn.env.Env) –
states_tuple (Tuple[gfn.states.DiscreteStates, gfn.states.DiscreteStates]) –
- Return type
torchtyping.TensorType[0, float]
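One plausible way the two loss terms combine is sketched below. The page states only that alpha is the weight for the reward matching loss; the additive form and the function name `combined_loss` are assumptions made for illustration.

```python
def combined_loss(fm_loss, rm_loss, alpha=1.0):
    """Flow matching loss plus the reward matching loss weighted by alpha.

    Assumption: the two terms are summed; only the role of alpha as the
    reward matching weight is documented.
    """
    return fm_loss + alpha * rm_loss


# Hypothetical scalar losses, chosen so the arithmetic is exact in floating point.
total = combined_loss(fm_loss=0.25, rm_loss=0.5, alpha=2.0)
```

Under this assumption, a larger alpha puts more emphasis on matching terminating flows to rewards relative to conserving flow at internal states.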
- reward_matching_loss(env, terminating_states)
Calculates the reward matching loss from the terminating states.
- Parameters
env (gfn.env.Env) –
terminating_states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[0, float]
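A reward matching term can be sketched as follows, assuming it compares the log-flow of the exit edge \(s \to s_f\) with the log-reward of each terminating state and averages the squared differences. The squared-error form, the mean reduction, and the function name are assumptions for illustration; this page does not specify them.

```python
import math


def reward_matching_loss_sketch(terminating_log_flows, rewards):
    """Mean squared difference between exit-edge log-flows and log-rewards.

    Assumption: squared error with a mean reduction over the batch.
    """
    residuals = [
        log_flow - math.log(reward)
        for log_flow, reward in zip(terminating_log_flows, rewards)
    ]
    return sum(r * r for r in residuals) / len(residuals)


# Hypothetical batch of two terminating states: the first exit flow equals its
# reward, the second is a factor of e too small (log-residual of -1).
log_flows = [math.log(2.0), math.log(5.0)]
rewards = [2.0, 5.0 * math.e]
rm_loss = reward_matching_loss_sketch(log_flows, rewards)
```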
- sample_trajectories(env, n_samples=1000)
Samples a specified number of complete trajectories.
- Parameters
env (gfn.env.Env) – the environment to sample trajectories from.
n_samples (int) – number of trajectories to be sampled.
- Returns
sampled trajectories object.
- Return type
- to_training_samples(trajectories)
Converts a batch of trajectories into a batch of training samples.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
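
The methods documented above suggest the following call order for computing a loss on freshly sampled data. This is a minimal sketch that relies only on the method names and signatures listed on this page. `StubGFlowNet` is a hypothetical stand-in so the example runs without the library; it does not reflect how FMGFlowNet is implemented.

```python
class StubGFlowNet:
    """Hypothetical stand-in exposing the three documented method names."""

    def sample_trajectories(self, env, n_samples=1000):
        # A real GFlowNet would roll out its policy in env.
        return [f"trajectory-{i}" for i in range(n_samples)]

    def to_training_samples(self, trajectories):
        # Pretend each trajectory yields one internal and one terminating state.
        return (list(trajectories), list(trajectories))

    def loss(self, env, states_tuple):
        internal_states, terminating_states = states_tuple
        # Placeholder value: just counts the states it received.
        return float(len(internal_states) + len(terminating_states))


def training_loss(gflownet, env, n_samples=16):
    """Sample trajectories, convert them to training samples, compute the loss."""
    trajectories = gflownet.sample_trajectories(env, n_samples=n_samples)
    states_tuple = gflownet.to_training_samples(trajectories)
    return gflownet.loss(env, states_tuple)


value = training_loss(StubGFlowNet(), env=None, n_samples=4)
```

With a real FMGFlowNet and environment, the same three calls would apply, and the returned loss would be passed to an optimizer step.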