gfn.gflownet.detailed_balance

Module Contents

Classes

DBGFlowNet

The Detailed Balance GFlowNet.

ModifiedDBGFlowNet

The Modified Detailed Balance GFlowNet. Only applicable to environments where all states are terminating.

class gfn.gflownet.detailed_balance.DBGFlowNet(pf, pb, logF, on_policy=False, forward_looking=False)

Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]

The Detailed Balance GFlowNet.

Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (excluding the sink state \(s_f\)) to \(\mathbb{R}^+\), parametrized in log space to avoid the non-negativity constraint; \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG; and \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton subset of it if self.logit_PB is a fixed DiscretePBEstimator.

Parameters
logF – a ScalarEstimator instance that estimates the log state flow.

on_policy – boolean indicating whether the log probabilities need to be re-evaluated.

forward_looking – whether to use the forward-looking GFlowNet loss.
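
As a reading aid for forward_looking, one common formulation of a forward-looking objective reparametrizes the state flow so that the estimator only learns a residual on top of the reward at intermediate states. The sketch below assumes that formulation and assumes rewards are defined for non-terminal states; \(\tilde{F}\) is notation introduced here, and this page does not confirm that the flag is implemented exactly this way.

\[
\log F(s) = \log \tilde{F}(s) + \log R(s)
\]

\[
\delta_{FL}(s \to s') = \log \tilde{F}(s) + \log R(s) + \log P_F(s' \mid s) - \log \tilde{F}(s') - \log R(s') - \log P_B(s \mid s')
\]

Under this assumption the estimator passed as logF would model \(\log \tilde{F}\) rather than \(\log F\) directly.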

get_scores(env, transitions)

Given a batch of transitions, calculate the scores.

Parameters
Raises
  • ValueError – when supplied with backward transitions.

  • AssertionError – when log rewards of transitions are None.

Return type

Tuple[torchtyping.TensorType[n_transitions, float], torchtyping.TensorType[n_transitions, float], torchtyping.TensorType[n_transitions, float]]

loss(env, transitions)

Detailed balance loss.

The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).

Parameters
Return type

torchtyping.TensorType[0, float]
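
To make the detailed balance objective concrete, here is a minimal, library-independent sketch of a per-transition score and a mean-squared loss over those scores. The function names `db_score` and `db_loss` are hypothetical and do not belong to gfn; the mean reduction is an assumption, and the real method operates on batched tensors rather than Python floats.

```python
import math


def db_score(log_f_s, log_pf, log_f_next=None, log_pb=None, log_reward=None):
    """Detailed balance residual for a single forward transition.

    Non-terminating transition s -> s':
        log F(s) + log P_F(s'|s) - log F(s') - log P_B(s|s')
    Terminating transition s -> s_f (pass log_reward = log R(s)):
        log F(s) + log P_F(s_f|s) - log R(s)
    """
    if log_reward is not None:
        return log_f_s + log_pf - log_reward
    return (log_f_s + log_pf) - (log_f_next + log_pb)


def db_loss(scores):
    """Mean of squared residuals (assumed reduction)."""
    return sum(score * score for score in scores) / len(scores)


# Toy chain s0 -> s1 -> s_f with R(s1) = 2 and a perfectly balanced flow:
# F(s0) = F(s1) = 2, and every transition probability equals 1 (log-prob 0).
balanced_scores = [
    db_score(math.log(2.0), 0.0, log_f_next=math.log(2.0), log_pb=0.0),
    db_score(math.log(2.0), 0.0, log_reward=math.log(2.0)),
]
balanced_loss = db_loss(balanced_scores)
```

When the flow and the policies satisfy detailed balance, as in the toy chain above, every score is zero and so is the loss; any mismatch in the flow estimate appears as a nonzero residual.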

to_training_samples(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) –

Return type

gfn.containers.Transitions
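
Conceptually, turning trajectories into transitions amounts to splitting each trajectory into consecutive state pairs. The sketch below illustrates that idea with plain Python lists; `trajectory_to_transitions` and the `SINK` marker are hypothetical names used only for illustration, and the actual gfn containers hold batched tensors with additional fields.

```python
SINK = "s_f"  # hypothetical marker for the sink state


def trajectory_to_transitions(states):
    """Split one trajectory into (state, next_state, is_done) triples.

    `states` is the ordered list of visited states, ending with SINK.
    `is_done` is True for the transition that enters the sink state.
    """
    return [
        (state, next_state, next_state == SINK)
        for state, next_state in zip(states, states[1:])
    ]


transitions = trajectory_to_transitions(["s0", "s1", "s2", SINK])
```

Each triple can then be scored independently, which is why detailed balance objectives train on transitions rather than on whole trajectories.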

class gfn.gflownet.detailed_balance.ModifiedDBGFlowNet

Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]

The Modified Detailed Balance GFlowNet. Only applicable to environments where all states are terminating.

See [Bayesian Structure Learning with Generative Flow Networks](https://arxiv.org/abs/2202.13903) for more details.

get_scores(transitions)

DAG-GFN-style detailed balance, when all states are connected to the sink.

Raises
  • ValueError – when backward transitions are supplied (not supported).

  • ValueError – when the computed scores contain inf.

Parameters

transitions (gfn.containers.Transitions) –

Return type

torchtyping.TensorType[n_trajectories, torch.float]
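
When every state can transition to the sink, the state flow can be eliminated and the balance condition written purely in terms of rewards and policies. The sketch below shows one way to compute such a score for a single transition s -> s', following the condition R(s) P_F(s'|s) P_F(s_f|s') = R(s') P_B(s|s') P_F(s_f|s). The function name `modified_db_score` and its arguments are hypothetical, not part of gfn, and the real method works on batched tensors.

```python
import math


def modified_db_score(log_r_s, log_r_next, log_pf, log_pb,
                      log_pf_exit_s, log_pf_exit_next):
    """Modified detailed balance residual for one transition s -> s'.

    log_r_s, log_r_next : log R(s), log R(s')
    log_pf              : log P_F(s'|s)
    log_pb              : log P_B(s|s')
    log_pf_exit_s       : log P_F(s_f|s)
    log_pf_exit_next    : log P_F(s_f|s')
    """
    preds = log_r_s + log_pf + log_pf_exit_next
    targets = log_r_next + log_pb + log_pf_exit_s
    score = preds - targets
    if math.isinf(score):
        # Mirrors the documented behaviour of rejecting infinite scores.
        raise ValueError("score is inf")
    return score


# Toy chain s0 -> s1 where both states are terminating, R(s0) = 1, R(s1) = 3.
# A balanced flow gives P_F(s1|s0) = 3/4, P_F(s_f|s0) = 1/4,
# P_F(s_f|s1) = 1 and P_B(s0|s1) = 1.
balanced_score = modified_db_score(
    log_r_s=math.log(1.0),
    log_r_next=math.log(3.0),
    log_pf=math.log(0.75),
    log_pb=0.0,
    log_pf_exit_s=math.log(0.25),
    log_pf_exit_next=0.0,
)
```

For the balanced toy flow above, the score vanishes up to floating-point error. A forward probability of zero (log-probability of negative infinity) produces an infinite score, which the sketch rejects with a ValueError.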

loss(env, transitions)

Calculates the modified detailed balance loss.

Parameters
Return type

torchtyping.TensorType[0, float]

to_training_samples(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) –

Return type

gfn.containers.Transitions