gfn.gflownet.detailed_balance
Module Contents
Classes
DBGFlowNet | The Detailed Balance GFlowNet.
ModifiedDBGFlowNet | The Modified Detailed Balance GFlowNet. Only applicable to environments where all states are terminating.
- class gfn.gflownet.detailed_balance.DBGFlowNet(pf, pb, logF, on_policy=False, forward_looking=False)
Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]
The Detailed Balance GFlowNet.
Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (excluding \(s_f\)) to \(\mathbb{R}^+\), parametrized in log space to avoid the non-negativity constraint; \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG; and \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton subset of it if self.logit_PB is a fixed DiscretePBEstimator.
- Parameters
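pf (gfn.modules.GFNModule) – estimator of the forward policy \(P_F\).
pb (gfn.modules.GFNModule) – estimator of the backward policy \(P_B\).
logF (gfn.modules.ScalarEstimator) – estimator of the log state flow \(\log F(s)\).
on_policy (bool) – defaults to False.
forward_looking (bool) – defaults to False.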
- logF
a ScalarEstimator instance.
- on_policy
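boolean indicating whether the transitions were sampled from the current pf, in which case the log-probabilities stored with them are reused; otherwise they are reevaluated.
- forward_looking
whether to use the forward-looking variant of the detailed balance loss.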
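A forward-looking loss is commonly built on the reparametrization \(\log F(s) = \log \tilde{F}(s) + \log R(s)\): the estimator only learns the residual \(\log \tilde{F}\), and every intermediate state must have a defined reward. The sketch below assumes that parametrization for a single non-terminating transition; the function and argument names are hypothetical and not part of the gfn API.

```python
def forward_looking_db_score(
    log_F_tilde_s: float,     # learned residual log F~(s) (hypothetical input)
    log_R_s: float,           # log R(s), reward evaluated on the intermediate state s
    log_pf: float,            # log P_F(s' | s)
    log_F_tilde_next: float,  # learned residual log F~(s')
    log_R_next: float,        # log R(s')
    log_pb: float,            # log P_B(s | s')
) -> float:
    """Hypothetical helper: DB residual for s -> s' under the ASSUMED
    forward-looking parametrization log F = log F~ + log R."""
    log_F_s = log_F_tilde_s + log_R_s           # reconstruct the full log-flow of s
    log_F_next = log_F_tilde_next + log_R_next  # and of s'
    return (log_F_s + log_pf) - (log_F_next + log_pb)
```

With all residuals at zero, the score reduces to \(\log R(s) - \log R(s') + \log P_F(s'\mid s) - \log P_B(s\mid s')\), so the reward difference alone already provides a per-step learning signal.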
- get_scores(env, transitions)
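Given a batch of transitions, compute the detailed balance score (the log-space residual of the balance condition) of each transition.
- Parameters
env (gfn.env.Env) – the environment the transitions come from.
transitions (gfn.containers.Transitions) – a batch of transitions.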
- Raises
ValueError – when supplied with backward transitions.
AssertionError – when log rewards of transitions are None.
- Return type
Tuple[torchtyping.TensorType[n_transitions, float], torchtyping.TensorType[n_transitions, float], torchtyping.TensorType[n_transitions, float]]
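Detailed balance asks that \(F(s)P_F(s'\mid s) = F(s')P_B(s\mid s')\) on every edge and, for a transition into the sink, \(F(s)P_F(s_f\mid s) = R(s)\). A score is the log-space difference between the two sides. The following pure-Python sketch computes such residuals for a batch stored as parallel lists; the argument names and batch layout are assumptions for illustration, not the library's internal code, and the env argument and forward-looking option are left out.

```python
from typing import List, Sequence


def db_scores(
    log_F_s: Sequence[float],      # log F(s) for each transition's source state
    log_pf: Sequence[float],       # log P_F(s' | s)
    log_F_next: Sequence[float],   # log F(s'); ignored when is_done is True
    log_pb: Sequence[float],       # log P_B(s | s'); ignored when is_done is True
    log_rewards: Sequence[float],  # log R(s); only used when is_done is True
    is_done: Sequence[bool],       # True if the transition goes to the sink s_f
) -> List[float]:
    """Hypothetical helper: (log F(s) + log P_F(s'|s)) minus the matching target."""
    scores = []
    for i in range(len(log_F_s)):
        pred = log_F_s[i] + log_pf[i]
        if is_done[i]:
            target = log_rewards[i]  # F(s) P_F(s_f | s) should equal R(s)
        else:
            target = log_F_next[i] + log_pb[i]  # F(s') P_B(s | s')
        scores.append(pred - target)
    return scores
```

A batch whose flows satisfy the condition yields scores of zero; any nonzero entry measures how far that edge is from balance.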
- loss(env, transitions)
Detailed balance loss.
The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
- Parameters
env (gfn.env.Env) –
transitions (gfn.containers.Transitions) –
- Return type
torchtyping.TensorType[0, float]
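The return type is a scalar, which suggests the per-transition scores are reduced to a single number. The sketch below assumes that reduction is the mean of the squared scores and checks it on a hand-built DAG: \(s_0 \to s_1\) and \(s_0 \to s_f\) with \(P_F = 0.5\) each, \(s_1 \to s_f\) with \(P_F = 1\), \(P_B(s_0\mid s_1) = 1\), and \(R(s_0) = R(s_1) = 1\). Balance then forces \(F(s_0) = 2\) and \(F(s_1) = 1\). All names here are hypothetical.

```python
import math
from typing import List, Sequence


def db_loss(scores: Sequence[float]) -> float:
    """ASSUMED form of the loss: mean of the squared DB residuals."""
    return sum(s * s for s in scores) / len(scores)


def toy_scores(F_s0: float) -> List[float]:
    """Residuals on the toy DAG described above, for a candidate value of F(s0)."""
    log = math.log
    F_s1, R_s0, R_s1 = 1.0, 1.0, 1.0
    return [
        log(F_s0) + log(0.5) - (log(F_s1) + log(1.0)),  # s0 -> s1 (P_B(s0|s1) = 1)
        log(F_s0) + log(0.5) - log(R_s0),               # s0 -> s_f
        log(F_s1) + log(1.0) - log(R_s1),               # s1 -> s_f
    ]


print(db_loss(toy_scores(2.0)))      # balanced flows: loss is ~0
print(db_loss(toy_scores(3.0)) > 0)  # over-estimated F(s0): True
```

Squaring penalizes over- and under-estimated flows symmetrically in log space, so the loss is minimized exactly when every edge is balanced.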
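- to_training_samples(trajectories)
Convert sampled trajectories into the transitions used as training samples.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
gfn.containers.Transitions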
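Because detailed balance is a per-edge objective, each complete trajectory \(s_0 \to \dots \to s_n \to s_f\) can be split into its consecutive transitions, with the last one flagged as terminating. The sketch below illustrates that split on plain lists of state labels; Transition, SINK, and to_transitions are hypothetical names, not the gfn containers API.

```python
from typing import List, NamedTuple, Sequence

SINK = "s_f"  # hypothetical label for the sink state


class Transition(NamedTuple):
    """Hypothetical stand-in for one entry of a batch of transitions."""
    state: str
    next_state: str
    is_done: bool  # True when next_state is the sink


def to_transitions(trajectory: Sequence[str]) -> List[Transition]:
    """Split a complete trajectory s0 -> ... -> sn -> s_f into its edges."""
    assert trajectory[-1] == SINK, "trajectory must end in the sink state"
    return [
        Transition(s, s_next, s_next == SINK)
        for s, s_next in zip(trajectory[:-1], trajectory[1:])
    ]


for t in to_transitions(["s0", "s1", "s2", SINK]):
    print(t)
```

A trajectory with \(n\) non-sink states therefore contributes \(n\) transitions, exactly one of which is terminating.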
- class gfn.gflownet.detailed_balance.ModifiedDBGFlowNet
Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]
The Modified Detailed Balance GFlowNet. Only applicable to environments where all states are terminating.
See [Bayesian Structure Learning with Generative Flow Networks](https://arxiv.org/abs/2202.13903) for more details.
- get_scores(transitions)
Compute DAG-GFN-style detailed balance scores, which apply when every state is connected to the sink (i.e., every state is terminating).
- Raises
ValueError – when backward transitions are supplied (not supported).
ValueError – when the computed scores contain inf.
- Parameters
transitions (gfn.containers.Transitions) –
- Return type
torchtyping.TensorType[n_transitions, torch.float]
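The modified condition can be derived from the standard one. If every state is terminating and the flow into the sink from \(s\) equals its reward, \(F(s)P_F(s_f\mid s) = R(s)\), then \(F(s) = R(s)/P_F(s_f\mid s)\). Substituting this into the detailed balance condition for an edge \(s \to s'\) eliminates \(F\):

\[
\begin{aligned}
F(s)\,P_F(s'\mid s) &= F(s')\,P_B(s\mid s') \\
\frac{R(s)\,P_F(s'\mid s)}{P_F(s_f\mid s)} &= \frac{R(s')\,P_B(s\mid s')}{P_F(s_f\mid s')} \\
\log R(s) + \log P_F(s'\mid s) + \log P_F(s_f\mid s') &= \log R(s') + \log P_B(s\mid s') + \log P_F(s_f\mid s).
\end{aligned}
\]

The score of a non-terminating transition is the left-hand side minus the right-hand side of the last line, so no state-flow estimator is required.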
- loss(env, transitions)
Calculates the modified detailed balance loss.
- Parameters
env (gfn.env.Env) –
transitions (gfn.containers.Transitions) –
- Return type
torchtyping.TensorType[0, float]
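The sketch below evaluates the modified detailed balance residual \(\log R(s) + \log P_F(s'\mid s) + \log P_F(s_f\mid s') - \log R(s') - \log P_B(s\mid s') - \log P_F(s_f\mid s)\) for one non-terminating transition and reduces a batch of residuals to a scalar. It assumes the reduction is a mean of squares; both function names are hypothetical.

```python
import math
from typing import Sequence


def modified_db_score(
    log_R_s: float,           # log R(s)
    log_R_next: float,        # log R(s')
    log_pf: float,            # log P_F(s' | s)
    log_pb: float,            # log P_B(s | s')
    log_pf_exit_s: float,     # log P_F(s_f | s)
    log_pf_exit_next: float,  # log P_F(s_f | s')
) -> float:
    """Hypothetical helper: modified DB residual for a non-terminating edge s -> s'.

    No flow estimator appears: F(s) has been replaced by R(s) / P_F(s_f | s).
    """
    pred = log_R_s + log_pf + log_pf_exit_next
    target = log_R_next + log_pb + log_pf_exit_s
    return pred - target


def modified_db_loss(scores: Sequence[float]) -> float:
    """ASSUMED form of the loss: mean of the squared residuals."""
    return sum(x * x for x in scores) / len(scores)


# Toy check: s0 -> s1 with P_F(s1|s0) = P_F(s_f|s0) = 0.5, P_F(s_f|s1) = 1,
# P_B(s0|s1) = 1 and R(s0) = R(s1) = 1; these values satisfy the condition.
score = modified_db_score(0.0, 0.0, math.log(0.5), 0.0, math.log(0.5), 0.0)
print(modified_db_loss([score]))  # 0.0
```

Only the policy's exit probabilities and the rewards enter the residual, which is why this variant needs every state to be terminating.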
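- to_training_samples(trajectories)
Convert sampled trajectories into the transitions used as training samples.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
gfn.containers.Transitions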