gfn.gflownet.flow_matching

Module Contents

Classes

FMGFlowNet

Flow Matching GFlowNet, with edge flow estimator.

class gfn.gflownet.flow_matching.FMGFlowNet(logF, alpha=1.0)

Bases: gfn.gflownet.base.GFlowNet[Tuple[gfn.states.DiscreteStates, gfn.states.DiscreteStates]]

Flow Matching GFlowNet, with edge flow estimator.

\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). Equivalently, it is the set of functions from the internal nodes (i.e. excluding \(s_f\)) to \(\mathbb{R}^{n_{\text{actions}}}\), ignoring the exit action. Positivity does not need to be enforced when the flows are parametrized in log-space.

The loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
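As a reading aid, the objective can be summarized as below. This is a sketch, not a transcription of the paper or of the code. Here \(F = \exp(\log F)\) denotes the flows given by logF, \(R\) is the reward, and the bar denotes a batch average. The outgoing sum at \(s'\) is assumed to include the exit edge \(s' \to s_f\) when \(s'\) can terminate. The \(\alpha\)-weighted sum follows the description of alpha below; the exact reduction used by the implementation is an assumption.

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{FM}}(s') &= \Big( \log \sum_{s \to s'} F(s \to s') - \log \sum_{s' \to s''} F(s' \to s'') \Big)^2, \\
\mathcal{L}_{\mathrm{RM}}(s) &= \big( \log F(s \to s_f) - \log R(s) \big)^2, \\
\mathcal{L} &= \overline{\mathcal{L}_{\mathrm{FM}}} + \alpha \, \overline{\mathcal{L}_{\mathrm{RM}}}.
\end{aligned}
```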

Parameters
  • logF (LogEdgeFlowEstimator) – estimator of the log edge flows.

  • alpha (float) – weight for the reward matching loss (default 1.0).

flow_matching_loss(env, states)

Computes the flow matching (FM) loss for the provided states.

At each state, the flow matching loss penalizes the difference between the log of the summed incoming flows and the log of the summed outgoing flows. The states must not include \(s_0\), and the batch shape must be (n_states,). Currently, only discrete environments are supported.
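The pure-Python sketch below illustrates this computation on a hypothetical toy DAG. The edge log-flows are stored in a dict rather than produced by a LogEdgeFlowEstimator on DiscreteStates. For each state, it takes a log-sum-exp over incoming and outgoing edge log-flows, then averages the squared gaps. The names `S0`, `SF`, `LOG_EDGE_FLOWS`, and the helper functions are illustrative only. Several details are assumptions rather than the library's API:

- exit edges into `SF` are counted as outgoing flow;
- the reduction is a mean of squares.

```python
import math
from collections import defaultdict

S0, SF = "s0", "sf"  # hypothetical labels for the initial and sink states

# Hypothetical toy DAG: each key is an edge (parent, child) and each value is
# the log edge flow log F(parent -> child). Edges into SF are exit actions.
LOG_EDGE_FLOWS = {
    (S0, "a"): math.log(2.0),
    (S0, "b"): math.log(1.0),
    ("a", "c"): math.log(2.0),
    ("b", "c"): math.log(1.0),
    ("c", "d"): math.log(2.0),
    ("c", "e"): math.log(1.0),
    ("d", SF): math.log(2.0),
    ("e", SF): math.log(1.0),
}


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def flow_matching_loss(log_edge_flows, states):
    """Mean over states of (log inflow - log outflow) ** 2."""
    # Mirrors the documented precondition: s0 has no incoming flow.
    assert S0 not in states, "states should not include s0"
    incoming, outgoing = defaultdict(list), defaultdict(list)
    for (parent, child), log_flow in log_edge_flows.items():
        outgoing[parent].append(log_flow)
        incoming[child].append(log_flow)
    gaps = [logsumexp(incoming[s]) - logsumexp(outgoing[s]) for s in states]
    return sum(g * g for g in gaps) / len(gaps)


print(flow_matching_loss(LOG_EDGE_FLOWS, ["a", "b", "c", "d", "e"]))  # 0.0
```

These toy flows are balanced at every state, so the loss is zero. Lowering a single edge flow, for example b → c, makes it positive.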

Raises
  • AssertionError – If the batch shape is not one-dimensional.

  • AssertionError – If any state is \(s_0\).

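Parameters
  • env (gfn.env.Env) – the environment the states belong to.

  • states (gfn.states.DiscreteStates) – the states to evaluate, with batch shape (n_states,) and excluding \(s_0\).

Return type

torchtyping.TensorType[n_trajectories, torch.float]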

loss(env, states_tuple)

Given a batch of non-terminal and terminal states, compute a loss.

Unlike the GFlowNet Foundations paper, this method accepts a tuple of states for more flexibility. The first element holds the internal states of the trajectories (i.e. non-terminal states), and the second holds the terminating states of the trajectories.
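The sketch below shows one way the tuple might be consumed. It assumes the total loss is the flow matching loss on the first element plus alpha times the reward matching loss on the second. This is consistent with the description of alpha but not stated explicitly here. `total_loss`, `fm_loss_fn`, and `rm_loss_fn` are hypothetical names, with simple callables standing in for the two methods.

```python
from typing import Callable, Sequence, Tuple, TypeVar

S = TypeVar("S")  # hypothetical state type


def total_loss(
    states_tuple: Tuple[Sequence[S], Sequence[S]],
    fm_loss_fn: Callable[[Sequence[S]], float],
    rm_loss_fn: Callable[[Sequence[S]], float],
    alpha: float = 1.0,
) -> float:
    """Flow matching on the internal states plus alpha-weighted reward matching."""
    internal_states, terminating_states = states_tuple
    return fm_loss_fn(internal_states) + alpha * rm_loss_fn(terminating_states)


# Constant toy losses standing in for flow_matching_loss and reward_matching_loss.
print(total_loss((["a", "c"], ["d"]), lambda s: 0.25, lambda s: 1.0, alpha=0.5))  # 0.75
```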

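Parameters
  • env (gfn.env.Env) – the environment.

  • states_tuple (tuple[gfn.states.DiscreteStates, gfn.states.DiscreteStates]) – the internal states and the terminating states, as returned by to_training_samples.

Return type

torchtyping.TensorType[0, float]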

reward_matching_loss(env, terminating_states)

Calculates the reward matching loss from the terminating states.
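The sketch below shows reward matching as the mean squared gap, over terminating states \(s\), between the log-flow of the exit edge \(s \to s_f\) and the log reward. It works on plain lists of floats rather than on DiscreteStates. Both the function name and the mean-of-squares reduction are assumptions, not the library's code.

```python
import math


def reward_matching_loss(log_exit_flows, log_rewards):
    """Mean over terminating states s of (log F(s -> s_f) - log R(s)) ** 2.

    log_exit_flows[i] is the log flow of the exit edge of the i-th
    terminating state and log_rewards[i] is that state's log reward.
    """
    assert len(log_exit_flows) == len(log_rewards) > 0
    squared = [(f - r) ** 2 for f, r in zip(log_exit_flows, log_rewards)]
    return sum(squared) / len(squared)


# Exit flows that equal the rewards give zero loss.
print(reward_matching_loss([math.log(2.0), 0.0], [math.log(2.0), 0.0]))  # 0.0
```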

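Parameters
  • env (gfn.env.Env) – the environment.

  • terminating_states (gfn.states.DiscreteStates) – the terminating states of the trajectories.

Return type

torchtyping.TensorType[0, float]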

sample_trajectories(env, n_samples=1000)

Sample a specific number of complete trajectories.

Parameters
  • env (gfn.env.Env) – the environment to sample trajectories from.

  • n_samples (int) – number of trajectories to be sampled.

Returns

The sampled trajectories.

Return type

Trajectories

to_training_samples(trajectories)

Converts a batch of trajectories into a batch of training samples.
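The sketch below splits a batch of trajectories into the (internal states, terminating states) tuple consumed by loss. It assumes each hypothetical trajectory is a plain list of states from `S0` to its terminating state, with the sink state \(s_f\) not stored.

"Internal states" are taken here to be every visited state except \(s_0\). Because \(s_f\) is never stored, the terminating states also appear in the first element. They are non-terminal in the sense that they are not \(s_f\). Whether the library selects exactly these states is an assumption.

```python
from typing import List, Sequence, Tuple

S0 = "s0"  # hypothetical label for the initial state


def to_training_samples(
    trajectories: Sequence[Sequence[str]],
) -> Tuple[List[str], List[str]]:
    """Split trajectories into (internal states, terminating states).

    Each trajectory lists its states from S0 to its terminating state;
    the sink state s_f is not stored.
    """
    internal = [s for traj in trajectories for s in traj if s != S0]
    terminating = [traj[-1] for traj in trajectories]
    return internal, terminating


print(to_training_samples([[S0, "a", "c", "d"], [S0, "b", "c", "e"]]))
# (['a', 'c', 'd', 'b', 'c', 'e'], ['d', 'e'])
```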

Parameters
  • trajectories (gfn.containers.Trajectories) – the batch of trajectories to convert.

Return type

tuple[gfn.states.DiscreteStates, gfn.states.DiscreteStates]