gfn.losses.detailed_balance
Module Contents
Classes
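DBParametrization | Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\): log state flows, forward policies, and backward policies. Useful for the Detailed Balance Loss.
DetailedBalance | Detailed Balance Loss, an edge-decomposable loss derived from the abstract base class for all GFN losses.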
Attributes
LossTensor
ScoresTensor
- class gfn.losses.detailed_balance.DBParametrization
Bases: gfn.losses.base.PFBasedParametrization
Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (excluding \(s_f\)) to \(\mathbb{R}^+\), which we parametrize in log space to avoid the non-negativity constraint; \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG; and \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof if self.logit_PB is a fixed LogitPBEstimator. Useful for the Detailed Balance Loss.
- logF : gfn.estimators.LogStateFlowEstimator
Log state-flow estimator, parametrizing \(\mathcal{O}_1\).
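To make the three factors concrete, here is a minimal, framework-free sketch on a hypothetical three-state DAG (s0 → s1, s0 → s2, s1 → s2, with every state also able to exit to \(s_f\)). Plain dictionaries stand in for the estimators; `CHILDREN`, `PARENTS`, `is_valid_forward`, `is_valid_backward`, and `uniform_backward` are illustrative names, not part of gfn. A fixed uniform backward policy plays the role of the fixed \(P_B\) that collapses \(\mathcal{O}_3\) to a singleton.

```python
import math

SINK = "s_f"

# Hypothetical toy DAG (not from gfn): s0 -> s1, s0 -> s2, s1 -> s2.
CHILDREN = {"s0": ["s1", "s2"], "s1": ["s2"], "s2": []}
PARENTS = {"s0": [], "s1": ["s0"], "s2": ["s0", "s1"]}

# O_1: state flows on internal states only (no s_f), stored as logs so that
# exp(log_F[s]) > 0 holds for any real value without an explicit constraint.
log_F = {"s0": math.log(4.0), "s1": math.log(2.0), "s2": math.log(2.0)}

# O_2: forward policy P_F(. | s) over the children of s plus the exit action.
P_F = {
    "s0": {"s1": 0.5, "s2": 0.25, SINK: 0.25},
    "s1": {"s2": 0.5, SINK: 0.5},
    "s2": {SINK: 1.0},
}


def _is_distribution(dist):
    """Non-negative probabilities that sum to one."""
    return all(p >= 0.0 for p in dist.values()) and math.isclose(sum(dist.values()), 1.0)


def is_valid_forward(p_f, children):
    """P_F is consistent with the DAG: mass only on child edges or the exit action."""
    return all(
        set(dist) <= set(children[s]) | {SINK} and _is_distribution(dist)
        for s, dist in p_f.items()
    )


def is_valid_backward(p_b, parents):
    """P_B is consistent with the DAG: mass only on parent edges."""
    return all(
        set(dist) <= set(parents[s]) and _is_distribution(dist)
        for s, dist in p_b.items()
    )


def uniform_backward(parents):
    """A fixed uniform P_B: choosing it reduces O_3 to a single element."""
    return {
        s: {p: 1.0 / len(ps) for p in ps}
        for s, ps in parents.items()
        if ps  # the initial state has no parents, hence no backward distribution
    }


P_B = uniform_backward(PARENTS)
print(is_valid_forward(P_F, CHILDREN), is_valid_backward(P_B, PARENTS))  # True True
print(P_B["s2"])  # {'s0': 0.5, 's1': 0.5}
```

In gfn itself these components are neural estimators (for example `logF` above); the dictionaries only illustrate the constraints each factor must satisfy.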
- class gfn.losses.detailed_balance.DetailedBalance(parametrization, on_policy=False)
Bases: gfn.losses.base.EdgeDecomposableLoss
Detailed Balance Loss, an edge-decomposable loss derived from the abstract base class for all GFN losses.
- Parameters
parametrization (DBParametrization)
on_policy (bool)
- __call__(transitions)
- Parameters
transitions (gfn.containers.Transitions)
- Return type
LossTensor
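This page only states that `__call__` maps a batch of `Transitions` to a `LossTensor`. A common reduction for detailed balance, assumed here rather than taken from gfn, is the mean squared residual over the batch, so the loss is zero exactly when every transition's score is zero. A minimal sketch with a hypothetical `detailed_balance_loss` helper operating on plain floats:

```python
def detailed_balance_loss(scores):
    """Mean squared detailed-balance residual over a batch (assumed reduction).

    `scores` holds one residual per transition, e.g. as produced by a
    get_scores-style function; this helper is illustrative, not gfn's API.
    """
    if not scores:
        raise ValueError("need at least one transition score")
    return sum(x * x for x in scores) / len(scores)


print(detailed_balance_loss([0.0, 0.0, 0.0]))   # 0.0
print(detailed_balance_loss([1.0, -1.0, 2.0]))  # 2.0
```

Squaring makes positive and negative residuals count equally, so flow imbalance in either direction is penalized.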
- get_modified_scores(transitions)
Computes DAG-GFN-style detailed balance scores, for use when every state is connected to the sink state \(s_f\) (i.e. every state can terminate).
- Parameters
transitions (gfn.containers.Transitions)
- Return type
ScoresTensor
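When every state can exit to \(s_f\), the terminating condition \(F(s)\,P_F(s_f \mid s) = R(s)\) gives \(F(s) = R(s) / P_F(s_f \mid s)\). Substituting this into the detailed-balance condition \(F(s)\,P_F(s' \mid s) = F(s')\,P_B(s \mid s')\) eliminates the state flow and yields \(R(s)\,P_F(s' \mid s)\,P_F(s_f \mid s') = R(s')\,P_B(s \mid s')\,P_F(s_f \mid s)\). The sketch below computes the log-space residual of that condition on a hypothetical toy DAG using plain dictionaries. The helper name `modified_db_score`, the dictionary layout, and the sign convention (left side minus right side) are assumptions, not necessarily what `get_modified_scores` does internally.

```python
import math

SINK = "s_f"

# Hypothetical toy values on the DAG s0 -> s1, s0 -> s2, s1 -> s2 (not from gfn).
# Keys are (s, s_next): LOG_PF[(s, s_next)] = log P_F(s_next | s) and
# LOG_PB[(s, s_next)] = log P_B(s | s_next). No state flow log F is needed.
LOG_R = {"s0": 0.0, "s1": 0.0, "s2": math.log(2.0)}
LOG_PF = {
    ("s0", "s1"): math.log(0.5), ("s0", "s2"): math.log(0.25), ("s0", SINK): math.log(0.25),
    ("s1", "s2"): math.log(0.5), ("s1", SINK): math.log(0.5),
    ("s2", SINK): 0.0,
}
LOG_PB = {("s0", "s1"): 0.0, ("s0", "s2"): math.log(0.5), ("s1", "s2"): math.log(0.5)}


def modified_db_score(s, s_next, log_r, log_pf, log_pb):
    """Log residual of R(s) P_F(s'|s) P_F(s_f|s') = R(s') P_B(s|s') P_F(s_f|s)."""
    lhs = log_r[s] + log_pf[(s, s_next)] + log_pf[(s_next, SINK)]
    rhs = log_r[s_next] + log_pb[(s, s_next)] + log_pf[(s, SINK)]
    return lhs - rhs


for edge in [("s0", "s1"), ("s0", "s2"), ("s1", "s2")]:
    score = modified_db_score(*edge, LOG_R, LOG_PF, LOG_PB)
    print(edge, math.isclose(score, 0.0, abs_tol=1e-12))  # True for all three edges
```

Only rewards and policies appear in the residual, which is why this variant needs no separate state-flow estimator.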
- get_scores(transitions)
- Parameters
transitions (gfn.containers.Transitions)
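The standard detailed-balance residual uses the learned state flow: for a non-terminating transition \(s \to s'\) the score is \(\log F(s) + \log P_F(s' \mid s) - \log F(s') - \log P_B(s \mid s')\), and for a terminating transition \(s \to s_f\) it is \(\log F(s) + \log P_F(s_f \mid s) - \log R(s)\). Below is a minimal sketch on a hypothetical toy DAG, assuming plain dictionaries and a `db_score` helper of our own; how `get_scores` batches transitions, masks sink padding, and what it returns are not documented on this page.

```python
import math

SINK = "s_f"

# Hypothetical toy parametrization on the DAG s0 -> s1, s0 -> s2, s1 -> s2,
# with exits to s_f; the numbers are chosen so that detailed balance holds.
LOG_F = {"s0": math.log(4.0), "s1": math.log(2.0), "s2": math.log(2.0)}
LOG_PF = {  # LOG_PF[(s, s_next)] = log P_F(s_next | s)
    ("s0", "s1"): math.log(0.5), ("s0", "s2"): math.log(0.25), ("s0", SINK): math.log(0.25),
    ("s1", "s2"): math.log(0.5), ("s1", SINK): math.log(0.5),
    ("s2", SINK): 0.0,
}
LOG_PB = {  # LOG_PB[(s, s_next)] = log P_B(s | s_next)
    ("s0", "s1"): 0.0, ("s0", "s2"): math.log(0.5), ("s1", "s2"): math.log(0.5),
}
LOG_R = {"s0": 0.0, "s1": 0.0, "s2": math.log(2.0)}


def db_score(s, s_next, log_f, log_pf, log_pb, log_r):
    """Detailed-balance residual of one forward transition s -> s_next."""
    pred = log_f[s] + log_pf[(s, s_next)]
    if s_next == SINK:
        # Terminating transition: the flow into the sink must match the reward R(s).
        target = log_r[s]
    else:
        target = log_f[s_next] + log_pb[(s, s_next)]
    return pred - target


transitions = [("s0", "s1"), ("s0", "s2"), ("s1", "s2"), ("s0", SINK), ("s1", SINK), ("s2", SINK)]
scores = [db_score(s, t, LOG_F, LOG_PF, LOG_PB, LOG_R) for s, t in transitions]
print(all(math.isclose(x, 0.0, abs_tol=1e-12) for x in scores))  # True
```

Perturbing any single value, for example raising \(F(s_1)\), makes the scores of the edges touching that state nonzero, which is the signal the loss trains against.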
- gfn.losses.detailed_balance.LossTensor
- gfn.losses.detailed_balance.ScoresTensor