gfn.losses.trajectory_balance
Implementations of the [Trajectory Balance loss](https://arxiv.org/abs/2201.13259) and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446).
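For reference, the per-trajectory objectives of the two cited papers can be written as below. The formulas assume a complete trajectory \(\tau = (s_0 \to s_1 \to \cdots \to s_n)\) whose last state is the terminal object \(x\), with reward \(R(x)\), forward policy \(P_F\), backward policy \(P_B\) and partition-function estimate \(Z\). The symbols \(\tau\), \(\zeta\), \(\mathcal{L}_{TB}\) and \(\mathcal{L}_{V}\) are introduced here for illustration and do not name anything in this module. The expectation in \(\mathcal{L}_{V}\) is typically estimated by a batch mean.

```latex
\begin{aligned}
\mathcal{L}_{TB}(\tau) &= \left( \log Z + \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t) - \log R(x) - \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}) \right)^2, \\
\zeta(\tau) &= \log R(x) + \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}) - \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t), \\
\mathcal{L}_{V}(\tau) &= \left( \zeta(\tau) - \mathbb{E}_{\tau'}\left[ \zeta(\tau') \right] \right)^2.
\end{aligned}
```

Only the Trajectory Balance objective involves \(\log Z\). The variance objective uses the mean of \(\zeta\) in its place. This is consistent with the signatures listed below: TrajectoryBalance takes a TBParametrization, which carries logZ, while LogPartitionVarianceLoss takes a plain gfn.losses.base.PFBasedParametrization.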
Module Contents
Classes
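- LogPartitionVarianceLoss: the Log Partition Variance loss.
- TBParametrization: the parametrization \(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\) used by the Trajectory Balance loss.
- TrajectoryBalance: the Trajectory Balance loss.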
Attributes
- LossTensor
- ScoresTensor
- class gfn.losses.trajectory_balance.LogPartitionVarianceLoss(parametrization, log_reward_clip_min=-12, on_policy=False)
Bases: gfn.losses.base.TrajectoryDecomposableLoss
- Parameters
parametrization (gfn.losses.base.PFBasedParametrization)
log_reward_clip_min (float)
on_policy (bool)
- __call__(trajectories)
- Parameters
trajectories (gfn.containers.Trajectories)
- Return type
LossTensor
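The entry above lists only the signature. The following plain-Python sketch shows the computation the cited Log Partition Variance paper describes, applied to a batch of precomputed per-trajectory sums. Each trajectory yields an estimate \(\zeta\) of \(\log Z\), and the loss is the mean squared deviation of these estimates from their batch mean. It is not the torchgfn implementation, which works on gfn.containers.Trajectories and tensors. The function name and its inputs are hypothetical. Clamping log-rewards from below at log_reward_clip_min is an assumption based on the parameter name, and on_policy is not modeled.

```python
from typing import Sequence


def log_partition_variance_loss(
    sum_log_pf: Sequence[float],
    sum_log_pb: Sequence[float],
    log_rewards: Sequence[float],
    log_reward_clip_min: float = -12.0,
) -> float:
    """Hypothetical sketch: variance of per-trajectory log Z estimates over a batch.

    sum_log_pf[i] and sum_log_pb[i] hold sum_t log P_F and sum_t log P_B for trajectory i.
    """
    if not log_rewards or not (len(sum_log_pf) == len(sum_log_pb) == len(log_rewards)):
        raise ValueError("expected non-empty sequences of equal length")
    # Assumed behavior of log_reward_clip_min: clamp each log-reward from below.
    clipped = [max(log_r, log_reward_clip_min) for log_r in log_rewards]
    # Per-trajectory estimate of log Z: log R(x) + sum log P_B - sum log P_F.
    zetas = [log_r + pb - pf for pf, pb, log_r in zip(sum_log_pf, sum_log_pb, clipped)]
    mean_zeta = sum(zetas) / len(zetas)
    # Mean squared deviation from the batch mean (population variance of the estimates).
    return sum((z - mean_zeta) ** 2 for z in zetas) / len(zetas)


# Usage: two trajectories whose log Z estimates are -1 and -3.
print(log_partition_variance_loss([0.0, 0.0], [0.0, 0.0], [-1.0, -3.0]))  # 1.0
```

Because only deviations from the batch mean enter the loss, trajectories that all agree on the same \(\zeta\) give zero loss regardless of its value.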
- gfn.losses.trajectory_balance.LossTensor
- gfn.losses.trajectory_balance.ScoresTensor
- class gfn.losses.trajectory_balance.TBParametrization
Bases: gfn.losses.base.PFBasedParametrization
Parametrization used by the Trajectory Balance loss. Its parameter space is \(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) is the set of possible values of logZ, \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG, and \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof if self.logit_PB is a fixed LogitPBEstimator.
- logZ: gfn.estimators.LogZEstimator
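To make the product \(\mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\) concrete, the sketch below bundles one element of each factor (a scalar log Z, a forward log-probability function and a backward log-probability function) and evaluates the Trajectory Balance residual for a single trajectory. Everything in it is hypothetical: ToyTBParametrization, uniform_bit_pf and tree_pb are illustrative names, not torchgfn classes, and the toy DAG builds 2-bit strings by appending one bit at a time. Because every state in that DAG has exactly one parent, the backward policy is forced to equal 1, an instance of \(\mathcal{O}_3\) being a singleton.

```python
import math
from dataclasses import dataclass
from typing import Callable, Sequence

State = str  # hypothetical state type: a partially built bit string


@dataclass
class ToyTBParametrization:
    """Hypothetical stand-in for the triple (logZ, P_F, P_B); not the torchgfn class."""

    log_z: float                                # element of O_1 = R
    log_pf: Callable[[State, State], float]     # element of O_2: (s, s_next) -> log P_F(s_next | s)
    log_pb: Callable[[State, State], float]     # element of O_3: (s_next, s) -> log P_B(s | s_next)

    def tb_residual(self, trajectory: Sequence[State], log_reward: float) -> float:
        """log Z + sum log P_F - log R(x) - sum log P_B for one complete trajectory."""
        transitions = list(zip(trajectory[:-1], trajectory[1:]))
        sum_log_pf = sum(self.log_pf(s, s_next) for s, s_next in transitions)
        sum_log_pb = sum(self.log_pb(s_next, s) for s, s_next in transitions)
        return self.log_z + sum_log_pf - log_reward - sum_log_pb


def uniform_bit_pf(s: State, s_next: State) -> float:
    # Each non-terminal state appends '0' or '1' with probability 1/2.
    return math.log(0.5)


def tree_pb(s_next: State, s: State) -> float:
    # Every state has a single parent (drop the last bit), so P_B = 1 and log P_B = 0.
    return 0.0


param = ToyTBParametrization(log_z=0.0, log_pf=uniform_bit_pf, log_pb=tree_pb)
# TB residual for the trajectory '' -> '0' -> '01' under a uniform reward of 1/4.
residual = param.tb_residual(["", "0", "01"], log_reward=math.log(0.25))
```

With \(Z = 1\) and a uniform reward of \(1/4\) on the four terminal strings, the residual is zero up to floating-point error; setting log_z to 1.0 shifts it to 1.0, and its square is that trajectory's Trajectory Balance loss.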
- class gfn.losses.trajectory_balance.TrajectoryBalance(parametrization, log_reward_clip_min=-12, on_policy=False)
Bases: gfn.losses.base.TrajectoryDecomposableLoss
- Parameters
parametrization (TBParametrization)
log_reward_clip_min (float)
on_policy (bool)
- __call__(trajectories)
- Parameters
trajectories (gfn.containers.Trajectories)
- Return type
LossTensor
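A minimal plain-Python sketch of the batch Trajectory Balance loss follows. It assumes that per-trajectory sums of forward and backward log-probabilities have already been computed and that the loss is the mean of the squared residuals over the batch. It is not the torchgfn implementation: trajectory_balance_loss is a hypothetical name, and the scalar log_z stands in for the value held by TBParametrization.logZ. Clamping log-rewards from below at log_reward_clip_min is an assumption based on the parameter name, and on_policy is not modeled.

```python
from typing import Sequence


def trajectory_balance_loss(
    log_z: float,
    sum_log_pf: Sequence[float],
    sum_log_pb: Sequence[float],
    log_rewards: Sequence[float],
    log_reward_clip_min: float = -12.0,
) -> float:
    """Hypothetical sketch: mean squared Trajectory Balance residual over a batch.

    sum_log_pf[i] and sum_log_pb[i] hold sum_t log P_F and sum_t log P_B for trajectory i.
    """
    if not log_rewards or not (len(sum_log_pf) == len(sum_log_pb) == len(log_rewards)):
        raise ValueError("expected non-empty sequences of equal length")
    total = 0.0
    for pf, pb, log_r in zip(sum_log_pf, sum_log_pb, log_rewards):
        # Assumed behavior of log_reward_clip_min: clamp the log-reward from below.
        log_r = max(log_r, log_reward_clip_min)
        # TB residual: log Z + sum log P_F - log R(x) - sum log P_B.
        residual = log_z + pf - log_r - pb
        total += residual ** 2
    return total / len(log_rewards)


# Usage: two trajectories with log Z = 2 and zero log-probability sums.
print(trajectory_balance_loss(2.0, [0.0, 0.0], [0.0, 0.0], [0.0, 1.0]))  # 2.5
```

Under the clamping assumption, the default of -12 floors rewards at roughly \(e^{-12} \approx 6.1 \times 10^{-6}\), which keeps a zero reward (\(\log 0 = -\infty\)) from producing an infinite loss.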