gfn.gflownet

Submodules

Package Contents

Classes

DBGFlowNet

The Detailed Balance GFlowNet.

FMGFlowNet

Flow Matching GFlowNet, with edge flow estimator.

GFlowNet

Abstract Base Class for GFlowNets.

LogPartitionVarianceGFlowNet

GFlowNet for the Log Partition Variance loss.

ModifiedDBGFlowNet

The Modified Detailed Balance GFlowNet. Only applicable to environments where all states are terminating.

PFBasedGFlowNet

Base class for GFlowNets that explicitly use \(P_F\).

SubTBGFlowNet

GFlowNet for the Sub Trajectory Balance Loss.

TBGFlowNet

Holds the logZ estimate for the Trajectory Balance loss.

TrajectoryBasedGFlowNet

Base class for GFlowNets that explicitly use \(P_F\).

class gfn.gflownet.DBGFlowNet(pf, pb, logF, on_policy=False, forward_looking=False)

Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]

The Detailed Balance GFlowNet.

Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (no \(s_f\)) to \(\mathbb{R}^+\) (which we parametrize with logs, to avoid the non-negativity constraint), and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator.

Parameters
logF

a ScalarEstimator instance.

on_policy

boolean indicating whether the samples were generated on-policy, which determines whether the stored log-probabilities can be reused or must be re-evaluated.

forward_looking

whether to use the forward-looking GFlowNet loss.

get_scores(env, transitions)

Given a batch of transitions, calculate the scores.

Raises
  • ValueError – when supplied with backward transitions.

  • AssertionError – when log rewards of transitions are None.

Return type

Tuple[torchtyping.TensorType[n_transitions, float], torchtyping.TensorType[n_transitions, float], torchtyping.TensorType[n_transitions, float]]

loss(env, transitions)

Detailed balance loss.

The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).

Return type

torchtyping.TensorType[0, float]
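To illustrate the detailed balance condition \(F(s) P_F(s' \mid s) = F(s') P_B(s \mid s')\), here is a minimal plain-Python sketch operating on precomputed log-quantities for a batch of transitions. It is not the torchgfn implementation: the function name and arguments are hypothetical, the forward-looking variant is ignored, and it assumes that for a transition ending in \(s_f\) the target \(\log F(s') + \log P_B(s \mid s')\) is replaced by \(\log R(s)\).

```python
from typing import Optional, Sequence


def detailed_balance_loss(
    log_f_s: Sequence[float],
    log_pf: Sequence[float],
    log_pb: Sequence[float],
    log_f_next: Sequence[Optional[float]],
    log_rewards: Sequence[Optional[float]],
) -> float:
    """Mean squared detailed-balance residual over a batch of transitions.

    Non-terminating transition s -> s' (hypothetical encoding: log_f_next[i] is a float):
        residual = log F(s) + log P_F(s'|s) - log F(s') - log P_B(s|s')
    Terminating transition s -> s_f (log_f_next[i] is None):
        residual = log F(s) + log P_F(s_f|s) - log R(s)
    """
    residuals = []
    for lf, lpf, lpb, lf_next, lr in zip(log_f_s, log_pf, log_pb, log_f_next, log_rewards):
        if lf_next is None:
            target = lr  # boundary condition: the flow into s_f equals the reward
        else:
            target = lf_next + lpb
        residuals.append(lf + lpf - target)
    return sum(r * r for r in residuals) / len(residuals)
```

For example, one balanced intermediate transition and one terminating transition with residual 1.0 give a loss of 0.5.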

to_training_samples(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) –

Return type

gfn.containers.Transitions

class gfn.gflownet.FMGFlowNet(logF, alpha=1.0)

Bases: gfn.gflownet.base.GFlowNet[Tuple[gfn.states.DiscreteStates, gfn.states.DiscreteStates]]

Flow Matching GFlowNet, with edge flow estimator.

\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\), which is equivalent to the set of functions from the internal nodes (i.e. without \(s_f\)) to \(\mathbb{R}^{n_{\text{actions}}}\), without the exit action. Positivity does not need to be enforced, because we parametrize log-flows.

The loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).

Parameters
logF

a LogEdgeFlowEstimator instance.

alpha

weight for the reward matching loss.

flow_matching_loss(env, states)

Computes the flow matching loss for the provided states.

The flow matching loss penalizes, for each state, the squared difference between the log-sum of its incoming flows and the log-sum of its outgoing flows. The states must not include \(s_0\), and the batch shape must be (n_states,). Currently, only discrete environments are supported.

Raises
  • AssertionError – If the batch shape is not linear.

  • AssertionError – If any state is at \(s_0\).

Return type

torchtyping.TensorType[n_trajectories, torch.float]
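The flow matching condition can be sketched in plain Python, given the log-flows of each state's incoming and outgoing edges. This is a hypothetical helper, not torchgfn's `flow_matching_loss`; it uses a numerically stable log-sum-exp, squares the per-state difference, and averages over states. Counting the exit edge \(s \to s_f\) among the outgoing edges is an assumption of the sketch.

```python
import math
from typing import List, Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values)) for finite, non-empty input."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def flow_matching_loss(
    incoming_log_flows: List[Sequence[float]],
    outgoing_log_flows: List[Sequence[float]],
) -> float:
    """Mean over states of (log sum incoming flows - log sum outgoing flows)^2.

    incoming_log_flows[i]: log F(parent -> s_i) for every parent of s_i.
    outgoing_log_flows[i]: log F(s_i -> child) for every child of s_i
    (in this sketch, the exit edge s_i -> s_f counts as an outgoing edge).
    """
    per_state = [
        (logsumexp(inc) - logsumexp(out)) ** 2
        for inc, out in zip(incoming_log_flows, outgoing_log_flows)
    ]
    return sum(per_state) / len(per_state)
```

A state with two incoming edges of flow 2 and one outgoing edge of flow 4 is balanced, so its contribution is (numerically) zero.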

loss(env, states_tuple)

Given a batch of non-terminal and terminal states, compute a loss.

Unlike the GFlowNet Foundations paper, we allow more flexibility by passing a tuple of states: the first element holds the internal states of the trajectories (i.e. non-terminal states), and the second holds their terminal states.

Return type

torchtyping.TensorType[0, float]

reward_matching_loss(env, terminating_states)

Calculates the reward matching loss from the terminating states.

Return type

torchtyping.TensorType[0, float]
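The reward matching term enforces the boundary condition \(F(x \to s_f) = R(x)\) on terminating states. The plain-Python sketch below pairs it with the `alpha` weight described in the constructor parameters. The helper names are hypothetical, and combining the two terms as `fm + alpha * rm` is an assumption about how the total loss is formed, not a statement of torchgfn's code.

```python
from typing import Sequence


def reward_matching_loss(
    exit_log_flows: Sequence[float], log_rewards: Sequence[float]
) -> float:
    """Mean of (log F(x -> s_f) - log R(x))^2 over terminating states x."""
    residuals = [lf - lr for lf, lr in zip(exit_log_flows, log_rewards)]
    return sum(r * r for r in residuals) / len(residuals)


def total_fm_loss(fm_loss: float, rm_loss: float, alpha: float = 1.0) -> float:
    """Assumed combination: flow matching plus alpha-weighted reward matching."""
    return fm_loss + alpha * rm_loss
```

A larger `alpha` in this sketch puts more emphasis on getting the terminal flows right relative to the interior flow consistency.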

sample_trajectories(env, n_samples=1000)

Sample a specific number of complete trajectories.

Parameters
  • env (gfn.env.Env) – the environment to sample trajectories from.

  • n_samples (int) – number of trajectories to be sampled.

Returns

sampled trajectories object.

Return type

Trajectories

to_training_samples(trajectories)

Converts a batch of trajectories into a batch of training samples.

Parameters

trajectories (gfn.containers.Trajectories) –

Return type

tuple[gfn.states.DiscreteStates, gfn.states.DiscreteStates]

class gfn.gflownet.GFlowNet

Bases: abc.ABC, torch.nn.Module, Generic[TrainingSampleType]

Abstract Base Class for GFlowNets.

A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266).

abstract loss(env, training_objects)

Computes the loss given the training objects.

Parameters

env (gfn.env.Env) –

sample_terminating_states(env, n_samples)

Rolls out the parametrization’s policy and returns the terminating states.

Parameters
  • env (gfn.env.Env) – the environment to sample terminating states from.

  • n_samples (int) – number of terminating states to be sampled.

Returns

sampled terminating states object.

Return type

States

abstract sample_trajectories(env, n_samples)

Sample a specific number of complete trajectories.

Parameters
  • env (gfn.env.Env) – the environment to sample trajectories from.

  • n_samples (int) – number of trajectories to be sampled.

Returns

sampled trajectories object.

Return type

Trajectories

abstract to_training_samples(trajectories)

Converts trajectories to training samples. The type depends on the GFlowNet.

Parameters

trajectories (gfn.containers.Trajectories) –

Return type

TrainingSampleType
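The abstract interface above implies a three-step training pattern: sample trajectories, convert them into the GFlowNet-specific training samples, then compute the loss. The toy sketch below mirrors that shape with `abc.ABC` and `typing.Generic`. `SketchGFlowNet`, `ToyTransitionGFlowNet`, the list-based "trajectories", and the placeholder loss are all hypothetical stand-ins (the real base class is also a `torch.nn.Module`); `env` is kept only to mirror the documented signatures.

```python
import abc
from typing import Generic, List, Tuple, TypeVar

TrainingSampleType = TypeVar("TrainingSampleType")


class SketchGFlowNet(abc.ABC, Generic[TrainingSampleType]):
    """Minimal stand-in for the documented GFlowNet interface."""

    @abc.abstractmethod
    def sample_trajectories(self, env, n_samples: int) -> List[List[int]]: ...

    @abc.abstractmethod
    def to_training_samples(self, trajectories: List[List[int]]) -> TrainingSampleType: ...

    @abc.abstractmethod
    def loss(self, env, training_objects: TrainingSampleType) -> float: ...


class ToyTransitionGFlowNet(SketchGFlowNet[List[Tuple[int, int]]]):
    """Hypothetical subclass whose training samples are (s, s') transitions."""

    def sample_trajectories(self, env, n_samples: int) -> List[List[int]]:
        # Deterministic toy "rollouts": trajectory i visits states 0, 1, ..., i + 1.
        return [list(range(i + 2)) for i in range(n_samples)]

    def to_training_samples(self, trajectories):
        # Flatten each trajectory into consecutive (s, s') pairs.
        return [(s, t) for traj in trajectories for s, t in zip(traj, traj[1:])]

    def loss(self, env, training_objects):
        # Placeholder: a real loss would be a differentiable tensor.
        return float(len(training_objects))


model = ToyTransitionGFlowNet()
trajectories = model.sample_trajectories(env=None, n_samples=3)
samples = model.to_training_samples(trajectories)
print(model.loss(env=None, training_objects=samples))  # → 6.0
```

The generic parameter is what lets each subclass declare its own sample type (transitions, trajectories, or state tuples) while sharing one training loop.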

class gfn.gflownet.LogPartitionVarianceGFlowNet(pf, pb, on_policy=False, log_reward_clip_min=-12)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

GFlowNet for the Log Partition Variance loss. Instead of learning a logZ estimate, it penalizes the variance of the per-trajectory log-partition estimates within a batch.

log_reward_clip_min

minimum value to which the log reward is clamped.

Raises

ValueError – if the loss is NaN.

loss(env, trajectories)

Log Partition Variance loss.

This loss is described in Section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).

Return type

torchtyping.TensorType[0, float]
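Each complete trajectory \(\tau\) ending in \(x\) yields an estimate \(\zeta(\tau) = \log R(x) + \log P_B(\tau) - \log P_F(\tau)\) of \(\log Z\); the loss penalizes the spread of these estimates across the batch, so no separate logZ parameter is learned. A minimal plain-Python sketch, assuming per-trajectory totals of the forward and backward log-probabilities are already available (the function name is hypothetical):

```python
import math
from typing import Sequence


def log_partition_variance_loss(
    total_log_pf: Sequence[float],
    total_log_pb: Sequence[float],
    log_rewards: Sequence[float],
    log_reward_clip_min: float = -12.0,
) -> float:
    """Empirical variance of per-trajectory log Z estimates over a batch."""
    # zeta(tau) = log R(x) + log P_B(tau) - log P_F(tau), with log R clamped from below.
    estimates = [
        max(lr, log_reward_clip_min) + lpb - lpf
        for lpf, lpb, lr in zip(total_log_pf, total_log_pb, log_rewards)
    ]
    mean = sum(estimates) / len(estimates)
    loss = sum((z - mean) ** 2 for z in estimates) / len(estimates)
    if math.isnan(loss):
        raise ValueError("loss is NaN")  # mirrors the documented ValueError
    return loss
```

When every trajectory implies the same \(\log Z\), the loss is zero regardless of what that value is.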

class gfn.gflownet.ModifiedDBGFlowNet

Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]

The Modified Detailed Balance GFlowNet. Only applicable to environments where all states are terminating.

See [Bayesian Structure Learning with Generative Flow Networks](https://arxiv.org/abs/2202.13903) for more details.

get_scores(transitions)

DAG-GFN-style detailed balance, when all states are connected to the sink.

Raises
  • ValueError – when backward transitions are supplied (not supported).

  • ValueError – when the computed scores contain inf.

Parameters

transitions (gfn.containers.Transitions) –

Return type

torchtyping.TensorType[n_trajectories, torch.float]

loss(env, transitions)

Calculates the modified detailed balance loss.

Return type

torchtyping.TensorType[0, float]
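When every state can terminate, the DAG-GFlowNet paper writes detailed balance in terms of the reward itself: \(R(s') P_B(s \mid s') P_F(s_f \mid s) = R(s) P_F(s' \mid s) P_F(s_f \mid s')\). The plain-Python sketch below computes the mean squared log-residual of that condition for a batch of transitions \(s \to s'\). The argument names are hypothetical and this is not torchgfn's `get_scores`; it only mirrors the documented check for infinite scores.

```python
import math
from typing import Sequence


def modified_db_loss(
    log_r_s: Sequence[float],
    log_r_next: Sequence[float],
    log_pf: Sequence[float],
    log_pb: Sequence[float],
    log_pf_exit_s: Sequence[float],
    log_pf_exit_next: Sequence[float],
) -> float:
    """Mean squared residual of the modified detailed balance condition."""
    residuals = [
        # [log R(s) + log P_F(s'|s) + log P_F(s_f|s')] - [log R(s') + log P_B(s|s') + log P_F(s_f|s)]
        (rs + pf + exit_next) - (rn + pb + exit_s)
        for rs, rn, pf, pb, exit_s, exit_next in zip(
            log_r_s, log_r_next, log_pf, log_pb, log_pf_exit_s, log_pf_exit_next
        )
    ]
    if any(math.isinf(r) for r in residuals):
        raise ValueError("scores contain inf")  # mirrors the documented ValueError
    return sum(r * r for r in residuals) / len(residuals)
```

Because no state-flow estimator appears, this objective only needs \(P_F\), \(P_B\), and the rewards of both endpoints of each transition.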

to_training_samples(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) –

Return type

gfn.containers.Transitions

class gfn.gflownet.PFBasedGFlowNet(pf, pb, on_policy=False)

Bases: GFlowNet

Base class for GFlowNets that explicitly use \(P_F\).

Parameters
pf

GFNModule

pb

GFNModule

sample_trajectories(env, n_samples)

Sample a specific number of complete trajectories.

Parameters
  • env (gfn.env.Env) – the environment to sample trajectories from.

  • n_samples (int) – number of trajectories to be sampled.

Returns

sampled trajectories object.

Return type

Trajectories

class gfn.gflownet.SubTBGFlowNet(pf, pb, logF, on_policy=False, weighting='geometric_within', lamda=0.9, log_reward_clip_min=-12, forward_looking=False)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

GFlowNet for the Sub Trajectory Balance Loss.

This method is described in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782).

Parameters
logF

a LogStateFlowEstimator instance.

weighting

sub-trajectory weighting scheme:

  • “DB”: Considers all one-step transitions of each trajectory in the batch and weights them equally (regardless of the length of the trajectory). Should be equivalent to the DetailedBalance loss.

  • “ModifiedDB”: Considers all one-step transitions of each trajectory in the batch and weights them inversely proportional to the trajectory length. This ensures that the loss is not dominated by long trajectories: each trajectory contributes equally to the loss.

  • “TB”: Considers only the full trajectory. Should be equivalent to the TrajectoryBalance loss.

  • “equal_within”: Each sub-trajectory of each trajectory is weighted equally within the trajectory. Then each trajectory is weighted equally within the batch.

  • “equal”: Each sub-trajectory of each trajectory is weighted equally within the set of all sub-trajectories.

  • “geometric_within”: Each sub-trajectory of each trajectory is weighted proportionally to (lamda ** len(sub_trajectory)), within each trajectory. This is the weighting used in the paper.

  • “geometric”: Each sub-trajectory of each trajectory is weighted proportionally to (lamda ** len(sub_trajectory)), within the set of all sub-trajectories.

lamda

discount factor for longer sub-trajectories, used by the geometric weighting schemes.

log_reward_clip_min

minimum value for log rewards.
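To make the “geometric_within” scheme concrete: a trajectory with \(n\) transitions has \(n - k + 1\) sub-trajectories of length \(k\), each weighted proportionally to `lamda ** k`, with the weights normalized to sum to one within the trajectory. The plain-Python sketch below computes the weight of a single sub-trajectory of each length; `lamda` is spelled as in the signature, and the function is a hypothetical illustration, not torchgfn's `get_geometric_within_contributions`, whose tensor layout may differ.

```python
from typing import List


def geometric_within_weights(n_transitions: int, lamda: float = 0.9) -> List[float]:
    """Weight of one sub-trajectory of each length k = 1..n_transitions.

    There are (n - k + 1) sub-trajectories of length k, each weighted by
    lamda ** k; weights are normalized so that all sub-trajectories of this
    trajectory together sum to 1.
    """
    n = n_transitions
    raw = [lamda ** k for k in range(1, n + 1)]
    total = sum((n - k + 1) * raw[k - 1] for k in range(1, n + 1))
    return [w / total for w in raw]
```

With `lamda < 1`, short sub-trajectories receive more weight per instance; as `lamda` grows above 1, the full trajectory (the “TB” end of the spectrum) dominates.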

calculate_log_state_flows(env, trajectories, log_pf_trajectories)

Calculate log state flows and masks for sink and terminal states.

Returns

  • log_state_flows – log state flows.

  • full_mask – a boolean tensor representing full states.

calculate_masks(log_state_flows, trajectories)

Calculate masks for sink and terminal states.

Return type

Tuple[MaskTensor, MaskTensor, MaskTensor]

calculate_preds(log_pf_trajectories_cum, log_state_flows, i)

Calculate the predictions tensor for the current sub-trajectory length.

Parameters
  • log_pf_trajectories_cum (CumulativeLogProbsTensor) –

  • log_state_flows (LogStateFlowsTensor) –

  • i (int) –

Return type

PredictionsTensor

calculate_targets(trajectories, preds, log_pb_trajectories_cum, log_state_flows, is_terminal_mask, sink_states_mask, full_mask, i)

Calculate the targets tensor for the current sub-trajectory length.

Parameters
  • trajectories (gfn.containers.Trajectories) –

  • preds (PredictionsTensor) –

  • log_pb_trajectories_cum (CumulativeLogProbsTensor) –

  • log_state_flows (LogStateFlowsTensor) –

  • is_terminal_mask (MaskTensor) –

  • sink_states_mask (MaskTensor) –

  • full_mask (MaskTensor) –

  • i (int) –

Return type

TargetsTensor

cumulative_logprobs(trajectories, log_p_trajectories)

Calculates the cumulative log probabilities for all trajectories.

Parameters
  • trajectories (gfn.containers.Trajectories) – a batch of trajectories.

  • log_p_trajectories (LogTrajectoriesTensor) – log probabilities of each transition in each trajectory.

Return type

CumulativeLogProbsTensor

Returns: cumulative sum of log probabilities of each trajectory.
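Cumulative sums make the log-probability of any sub-trajectory a single difference: the steps from state \(s_i\) to state \(s_j\) contribute `cum[j] - cum[i]`. A one-trajectory plain-Python sketch, assuming a leading zero is prepended so that `cum[0]` corresponds to the empty prefix (an assumption about the layout, not a description of the returned tensor):

```python
from itertools import accumulate
from typing import List, Sequence


def cumulative_logprobs(log_p_steps: Sequence[float]) -> List[float]:
    """Prefix sums of per-step log-probabilities, with a leading 0.0."""
    return [0.0] + list(accumulate(log_p_steps))


def sub_trajectory_logprob(cum: Sequence[float], i: int, j: int) -> float:
    """Log-probability of the sub-trajectory from state s_i to state s_j (i < j)."""
    return cum[j] - cum[i]
```

This turns the evaluation of all \(O(n^2)\) sub-trajectories into constant-time lookups after one linear pass.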

get_equal_contributions(trajectories)

Calculates contributions for the ‘equal’ weighting method.

Parameters

trajectories (gfn.containers.Trajectories) –

Return type

ContributionsTensor

get_equal_within_contributions(trajectories)

Calculates contributions for the ‘equal_within’ weighting method.

Parameters

trajectories (gfn.containers.Trajectories) –

Return type

ContributionsTensor

get_geometric_within_contributions(trajectories)

Calculates contributions for the ‘geometric_within’ weighting method.

Parameters

trajectories (gfn.containers.Trajectories) –

Return type

ContributionsTensor

get_modified_db_contributions(trajectories)

Calculates contributions for the ‘ModifiedDB’ weighting method.

Parameters

trajectories (gfn.containers.Trajectories) –

Return type

ContributionsTensor

get_scores(env, trajectories)

Scores all submitted trajectories.

Returns

  • A list of tensors, each representing the scores of all sub-trajectories of length k, for k in [1, …, trajectories.max_length], where the score of a sub-trajectory \(\tau\) is \(\log P_F(\tau) + \log F(\tau_0) - \log P_B(\tau) - \log F(\tau_{-1})\). The shape of the k-th tensor is (trajectories.max_length - k + 1, trajectories.n_trajectories), with k starting from 1.

  • A list of tensors representing what should be masked out in each element of the first list, given that not all sub-trajectories of length k exist for each trajectory. The entries of those tensors are True if the corresponding sub-trajectory does not exist.

Return type

Tuple[List[torchtyping.TensorType[0, float]], List[torchtyping.TensorType[0, float]]]
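For a single trajectory with states \(s_0, \dots, s_n\), the sub-trajectory score above can be computed for every pair \(i < j\) from prefix sums of the forward and backward log-probabilities. The plain-Python sketch below groups scores by length \(k = j - i\), mirroring the list-of-tensors structure for one trajectory only. Its names are hypothetical; it assumes `log_flows[-1]` already holds \(\log R(x)\) for the terminal state and ignores the masking of padded steps.

```python
from itertools import accumulate
from typing import Dict, List, Sequence


def sub_trajectory_scores(
    log_pf: Sequence[float], log_pb: Sequence[float], log_flows: Sequence[float]
) -> Dict[int, List[float]]:
    """Scores log P_F(tau) + log F(tau_0) - log P_B(tau) - log F(tau_-1).

    log_pf[t], log_pb[t]: log-probs of transition s_t -> s_{t+1} (length n).
    log_flows[t]: log F(s_t) for t = 0..n; the last entry is assumed to be
    the log-reward of the terminal state.
    Returns {k: [score of sub-trajectory s_i..s_{i+k} for i = 0..n-k]}.
    """
    n = len(log_pf)
    cum_pf = [0.0] + list(accumulate(log_pf))
    cum_pb = [0.0] + list(accumulate(log_pb))
    scores: Dict[int, List[float]] = {}
    for k in range(1, n + 1):
        scores[k] = [
            (cum_pf[i + k] - cum_pf[i]) + log_flows[i]
            - (cum_pb[i + k] - cum_pb[i]) - log_flows[i + k]
            for i in range(n - k + 1)
        ]
    return scores
```

The k-th entry holds \(n - k + 1\) scores, matching the (max_length - k + 1) leading dimension described above when \(n\) equals max_length.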

get_tb_contributions(trajectories, all_scores)

Calculates contributions for the ‘TB’ weighting method.

Return type

ContributionsTensor

loss(env, trajectories)
Return type

torchtyping.TensorType[0, float]

class gfn.gflownet.TBGFlowNet(pf, pb, on_policy=False, init_logZ=0.0, log_reward_clip_min=-12)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

Holds the logZ estimate for the Trajectory Balance loss.

\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) represents the possible values for logZ, and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator.

Parameters
init_logZ

initial value of the logZ estimate.

log_reward_clip_min

minimum value to which the log reward is clamped.

loss(env, trajectories)

Trajectory balance loss.

The trajectory balance loss is described in Section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).

Raises

ValueError – if the loss is NaN.

Return type

torchtyping.TensorType[0, float]
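For a complete trajectory \(\tau\) ending in \(x\), trajectory balance requires \(Z \prod_t P_F(s_{t+1} \mid s_t) = R(x) \prod_t P_B(s_t \mid s_{t+1})\). A plain-Python sketch of the corresponding loss, averaging the squared log-residual over the batch and clamping log-rewards from below at `log_reward_clip_min`. The function is hypothetical and takes per-trajectory totals rather than torchgfn containers.

```python
import math
from typing import Sequence


def trajectory_balance_loss(
    log_z: float,
    total_log_pf: Sequence[float],
    total_log_pb: Sequence[float],
    log_rewards: Sequence[float],
    log_reward_clip_min: float = -12.0,
) -> float:
    """Mean of (log Z + log P_F(tau) - log R(x) - log P_B(tau))^2."""
    residuals = [
        log_z + lpf - max(lr, log_reward_clip_min) - lpb
        for lpf, lpb, lr in zip(total_log_pf, total_log_pb, log_rewards)
    ]
    loss = sum(r * r for r in residuals) / len(residuals)
    if math.isnan(loss):
        raise ValueError("loss is NaN")  # mirrors the documented ValueError
    return loss
```

Unlike the log partition variance objective, `log_z` here is a single learned scalar shared by all trajectories.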

class gfn.gflownet.TrajectoryBasedGFlowNet(pf, pb, on_policy=False)

Bases: PFBasedGFlowNet[gfn.containers.Trajectories]

Base class for GFlowNets that explicitly use \(P_F\).

Parameters
pf

GFNModule

pb

GFNModule

get_pfs_and_pbs(trajectories, fill_value=0.0)

Evaluates log-probabilities for each transition in each trajectory in the batch.

More specifically, it evaluates \(\log P_F (s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in each trajectory in the batch.

This is useful when the policy used to sample the trajectories differs from the one used to evaluate the loss; otherwise, the log-probabilities stored in the trajectories can be used directly.

Parameters
  • trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.

  • fill_value (float) – Value to use for invalid states (i.e. the \(s_f\) states appended to pad shorter trajectories).

Return type

Tuple[torchtyping.TensorType[max_length, n_trajectories, torch.float], torchtyping.TensorType[max_length, n_trajectories, torch.float]]

Returns: A tuple of float tensors of shape (max_length, n_trajectories) containing the log_pf and log_pb for each action in each trajectory. The first one can be None.

Raises
  • ValueError – if the trajectories are backward.

  • AssertionError – when actions and states dimensions mismatch.


get_trajectories_scores(trajectories)

Given a batch of trajectories, calculate forward & backward policy scores.

Parameters

trajectories (gfn.containers.Trajectories) –

Return type

Tuple[torchtyping.TensorType[n_trajectories, torch.float], torchtyping.TensorType[n_trajectories, torch.float], torchtyping.TensorType[n_trajectories, torch.float]]
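Because shorter trajectories are padded in the (max_length, n_trajectories) log-probability tensors returned by get_pfs_and_pbs, summing over the time dimension yields per-trajectory totals as long as the padding value is 0.0 (the default fill_value). The plain-Python sketch below shows that reduction on nested lists. The function name is hypothetical, and returning `total_log_pf - total_log_pb - log_reward` as the third element is an assumption about what the scores contain (it equals the trajectory balance residual without the logZ term).

```python
from typing import List, Sequence, Tuple


def trajectories_scores(
    log_pf: Sequence[Sequence[float]],
    log_pb: Sequence[Sequence[float]],
    log_rewards: Sequence[float],
) -> Tuple[List[float], List[float], List[float]]:
    """Reduce padded (max_length x n_trajectories) log-probs to per-trajectory scores.

    Padding entries must equal 0.0 (the default fill_value) so that they do
    not change the sums.
    """
    n_trajectories = len(log_rewards)
    # Sum each column (one trajectory) over the time dimension (rows).
    total_pf = [sum(row[j] for row in log_pf) for j in range(n_trajectories)]
    total_pb = [sum(row[j] for row in log_pb) for j in range(n_trajectories)]
    scores = [pf - pb - lr for pf, pb, lr in zip(total_pf, total_pb, log_rewards)]
    return total_pf, total_pb, scores
```

In this example layout, the second trajectory has a single step, so its second row entry is the 0.0 padding.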

to_training_samples(trajectories)

Converts trajectories to training samples. The type depends on the GFlowNet.

Parameters

trajectories (gfn.containers.Trajectories) –

Return type

gfn.containers.Trajectories