gfn.losses.detailed_balance

Module Contents

Classes

DBParametrization

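Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) parametrizes (log) state flows and \(\mathcal{O}_2\), \(\mathcal{O}_3\) the forward and backward probabilities.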

DetailedBalance

Detailed Balance loss (an edge-decomposable GFN loss).

Attributes

LossTensor

ScoresTensor

class gfn.losses.detailed_balance.DBParametrization

Bases: gfn.losses.base.PFBasedParametrization

Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (excluding \(s_f\)) to \(\mathbb{R}^+\), which we parametrize through their logarithms to avoid the non-negativity constraint; \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG; and \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton containing that single function if self.logit_PB is a fixed LogitPBEstimator. Useful for the Detailed Balance loss.
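In log space, the three factors combine into the detailed balance condition that the Detailed Balance loss penalizes. The block below is a sketch of the standard formulation rather than a transcription of this module's source: for an edge \(s \to s'\) with \(s' \neq s_f\), the flow through the edge must agree in both directions, and for an edge into the sink, the flow out of \(s\) must match a reward \(R(s)\) (the reward is not defined on this page and is an assumption here). \(\log F\) comes from \(\mathcal{O}_1\), \(P_F\) from \(\mathcal{O}_2\) and \(P_B\) from \(\mathcal{O}_3\).

```latex
\begin{aligned}
\log F(s) + \log P_F(s' \mid s) &= \log F(s') + \log P_B(s \mid s') && (s' \neq s_f) \\
\log F(s) + \log P_F(s_f \mid s) &= \log R(s)
\end{aligned}
```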

logF: gfn.estimators.LogStateFlowEstimator

class gfn.losses.detailed_balance.DetailedBalance(parametrization, on_policy=False)

Bases: gfn.losses.base.EdgeDecomposableLoss

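Detailed Balance loss. As an edge-decomposable loss, it is evaluated on batches of transitions (edges \(s \to s'\)) rather than on complete trajectories.

Parameters

parametrization (DBParametrization) –

on_policy (bool) –

__call__(transitions)

Parameters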

transitions (gfn.containers.Transitions) – batch of transitions on which to evaluate the loss.

Return type

LossTensor
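As a rough illustration of what `__call__` computes, the sketch below evaluates a detailed-balance loss from per-transition log quantities in plain Python. It assumes, rather than reproduces, the library's behavior: each transition's score is \(\log F(s) + \log P_F(s' \mid s)\) minus either \(\log F(s') + \log P_B(s \mid s')\) or, for an edge into \(s_f\), \(\log R(s)\), and the loss is the mean squared score. `EdgeLogTerms`, `db_score` and `db_loss` are hypothetical names, not part of `gfn`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EdgeLogTerms:
    """Hypothetical per-transition log quantities (not a gfn class)."""
    log_F_s: float                      # log F(s) from the state-flow estimator
    log_pf: float                       # log P_F(s' | s)
    log_F_next: Optional[float] = None  # log F(s'), unused when s' = s_f
    log_pb: Optional[float] = None      # log P_B(s | s'), unused when s' = s_f
    log_reward: Optional[float] = None  # log R(s), set only when s' = s_f


def db_score(t: EdgeLogTerms) -> float:
    """Log-space residual of detailed balance for one edge s -> s'."""
    pred = t.log_F_s + t.log_pf
    if t.log_reward is not None:
        # Terminating edge s -> s_f: the target is log R(s).
        target = t.log_reward
    else:
        # Intermediate edge: the target is log F(s') + log P_B(s | s').
        if t.log_F_next is None or t.log_pb is None:
            raise ValueError("non-terminating edges need log_F_next and log_pb")
        target = t.log_F_next + t.log_pb
    return pred - target


def db_loss(batch: List[EdgeLogTerms]) -> float:
    """Mean squared detailed-balance residual over a batch of transitions."""
    if not batch:
        raise ValueError("empty batch")
    scores = [db_score(t) for t in batch]
    return sum(s * s for s in scores) / len(scores)


if __name__ == "__main__":
    batch = [
        # Intermediate edge that already satisfies detailed balance.
        EdgeLogTerms(log_F_s=0.0, log_pf=-1.0, log_F_next=-0.5, log_pb=-0.5),
        # Terminating edge whose outgoing flow overshoots log R(s) by 0.5.
        EdgeLogTerms(log_F_s=1.0, log_pf=-0.5, log_reward=0.0),
    ]
    print(db_loss(batch))  # 0.125
```

Squaring the log-space residual penalizes over- and under-estimation symmetrically, and averaging over edges is what makes the loss usable on arbitrary batches of transitions.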

get_modified_scores(transitions)

DAG-GFN-style detailed balance, for the case in which every state is connected to the sink state \(s_f\).

Parameters

transitions (gfn.containers.Transitions) – batch of transitions to score.

Return type

ScoresTensor
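The condition targeted by this method can be sketched as follows, assuming the DAG-GFlowNet formulation in which every state \(s\) can terminate, so that \(F(s)\,P_F(s_f \mid s) = R(s)\). Substituting \(F(s) = R(s) / P_F(s_f \mid s)\) into detailed balance eliminates the state flow and gives \(R(s)\,P_F(s' \mid s)\,P_F(s_f \mid s') = R(s')\,P_B(s \mid s')\,P_F(s_f \mid s)\). The function name, argument layout and sign convention below are hypothetical; the library's own scores may differ.

```python
def modified_db_score(
    log_r_s: float,           # log R(s)
    log_r_next: float,        # log R(s')
    log_pf_edge: float,       # log P_F(s' | s)
    log_pb_edge: float,       # log P_B(s | s')
    log_pf_exit_s: float,     # log P_F(s_f | s)
    log_pf_exit_next: float,  # log P_F(s_f | s')
) -> float:
    """Log residual of R(s) P_F(s'|s) P_F(s_f|s') = R(s') P_B(s|s') P_F(s_f|s).

    No state-flow estimate is needed: it has been eliminated using
    F(s) = R(s) / P_F(s_f | s), which holds when every state can terminate.
    """
    lhs = log_r_s + log_pf_edge + log_pf_exit_next
    rhs = log_r_next + log_pb_edge + log_pf_exit_s
    return lhs - rhs


if __name__ == "__main__":
    # Values built from log F(s) = 1.0 and log F(s') = 0.5 so that both the
    # termination and edge conditions hold; the residual is therefore zero.
    print(modified_db_score(0.0, 0.0, -1.5, -1.0, -1.0, -0.5))  # 0.0
```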

get_scores(transitions)

Parameters

transitions (gfn.containers.Transitions) – batch of transitions to score.

gfn.losses.detailed_balance.LossTensor

gfn.losses.detailed_balance.ScoresTensor