gfn.losses.trajectory_balance

Implementations of the [Trajectory Balance loss](https://arxiv.org/abs/2201.13259) and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446).
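For reference, both objectives can be written in terms of a per-trajectory score S(τ) for a complete trajectory τ = (s_0 → s_1 → … → s_n = x). This is a sketch following the linked papers. The batch-mean reduction over B trajectories and the sign convention chosen for S are presentation choices made here, not taken from this page. Trajectory Balance learns log Z explicitly. The Log Partition Variance loss instead minimizes the batch variance of S, which drives every trajectory to imply the same log Z estimate (namely −S(τ)) without a learned logZ.

```latex
% Per-trajectory score for tau = (s_0 -> s_1 -> ... -> s_n = x)
S(\tau) = \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t)
        - \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1})
        - \log R(x)

% Trajectory Balance, averaged over a batch tau_1, ..., tau_B
\mathcal{L}_{TB} = \frac{1}{B} \sum_{i=1}^{B} \bigl(\log Z + S(\tau_i)\bigr)^2

% Log Partition Variance: spread of the scores around their batch mean
\mathcal{L}_{LPV} = \frac{1}{B} \sum_{i=1}^{B} \bigl(S(\tau_i) - \bar{S}\bigr)^2,
\qquad \bar{S} = \frac{1}{B} \sum_{j=1}^{B} S(\tau_j)
```

The log_reward_clip_min argument of the classes below presumably bounds log R(x) from below before it enters S. That reading is an assumption based on the parameter name.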

Module Contents

Classes

LogPartitionVarianceLoss

TBParametrization

\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), the parametrization used by the Trajectory Balance loss.

TrajectoryBalance

Attributes

LossTensor

ScoresTensor

class gfn.losses.trajectory_balance.LogPartitionVarianceLoss(parametrization, log_reward_clip_min=-12, on_policy=False)

Bases: gfn.losses.base.TrajectoryDecomposableLoss

Parameters
  • parametrization –

  • log_reward_clip_min (float) –

  • on_policy (bool) –

__call__(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) –

Return type

LossTensor
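The Log Partition Variance computation can be illustrated with a minimal sketch. It uses plain Python floats instead of the library's tensors, and both function names (trajectory_score, log_partition_variance_loss) are hypothetical, not part of gfn. It also assumes that log_reward_clip_min acts as a lower bound on the log-reward, and that the loss is the mean squared deviation of the scores from their batch mean.

```python
from typing import Sequence


def trajectory_score(log_pf: Sequence[float], log_pb: Sequence[float],
                     log_reward: float, log_reward_clip_min: float = -12.0) -> float:
    """Hypothetical helper: sum log P_F - sum log P_B - clipped log R(x).

    Assumes log_reward_clip_min is a lower bound on the log-reward.
    """
    clipped_log_reward = max(log_reward, log_reward_clip_min)
    return sum(log_pf) - sum(log_pb) - clipped_log_reward


def log_partition_variance_loss(scores: Sequence[float]) -> float:
    """Hypothetical helper: batch variance of the per-trajectory scores.

    Each score is (minus) an implied log Z estimate, so a zero loss means
    every trajectory in the batch implies the same partition function.
    """
    if not scores:
        raise ValueError("need at least one trajectory score")
    mean = sum(scores) / len(scores)
    return sum((s - mean) ** 2 for s in scores) / len(scores)
```

For instance, two trajectories with scores 1.0 and 3.0 give a loss of 1.0, while a batch of identical scores gives 0.0. No logZ parameter appears anywhere, which is the main difference from TrajectoryBalance.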

gfn.losses.trajectory_balance.LossTensor
gfn.losses.trajectory_balance.ScoresTensor
class gfn.losses.trajectory_balance.TBParametrization

Bases: gfn.losses.base.PFBasedParametrization

\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) is the set of possible values of logZ, \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG, and \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or the singleton containing that fixed backward function if self.logit_PB is a fixed LogitPBEstimator. This parametrization is used by the Trajectory Balance loss.

logZ : gfn.estimators.LogZEstimator

class gfn.losses.trajectory_balance.TrajectoryBalance(parametrization, log_reward_clip_min=-12, on_policy=False)

Bases: gfn.losses.base.TrajectoryDecomposableLoss

Parameters
  • parametrization (TBParametrization) –

  • log_reward_clip_min (float) –

  • on_policy (bool) –

__call__(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) –

Return type

LossTensor
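The Trajectory Balance computation can be illustrated the same way, as a minimal sketch in plain Python floats rather than tensors. The function name trajectory_balance_loss is hypothetical, not part of gfn. The sketch assumes that log_reward_clip_min bounds the log-reward from below and that the per-trajectory squared residuals are averaged over the batch. Here log_z stands in for the value held by the parametrization's logZ estimator.

```python
from typing import Sequence


def trajectory_balance_loss(log_z: float,
                            log_pf: Sequence[Sequence[float]],
                            log_pb: Sequence[Sequence[float]],
                            log_rewards: Sequence[float],
                            log_reward_clip_min: float = -12.0) -> float:
    """Hypothetical helper: mean of (log Z + sum log P_F - log R - sum log P_B)^2.

    log_pf[i] and log_pb[i] hold the per-step log-probabilities of
    trajectory i; log_rewards[i] is the log-reward of its final state.
    """
    if not log_rewards or not (len(log_pf) == len(log_pb) == len(log_rewards)):
        raise ValueError("need matching, non-empty per-trajectory inputs")
    residuals = []
    for pf, pb, log_r in zip(log_pf, log_pb, log_rewards):
        # Assumed semantics: clip the log-reward from below before using it.
        clipped_log_r = max(log_r, log_reward_clip_min)
        score = sum(pf) - sum(pb) - clipped_log_r
        residuals.append(log_z + score)
    return sum(r * r for r in residuals) / len(residuals)
```

A trajectory with a single forward step of log-probability -1.0, a backward log-probability of 0.0, and a log-reward of 0.0 is perfectly balanced when log_z = 1.0, giving a loss of 0.0. Setting log_z = 0.0 instead gives a loss of 1.0.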