gfn.gflownet.trajectory_balance
Implementations of the [Trajectory Balance loss](https://arxiv.org/abs/2201.13259) and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446).
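The two objectives can be summarized in common GFlowNet notation. This summary assumes a complete trajectory \(\tau = (s_0 \to s_1 \to \dots \to s_n = x)\) ending in terminal state \(x\), reward \(R\), and forward and backward policies \(P_F\) and \(P_B\). Trajectory Balance learns \(\log Z\) as a parameter. Log Partition Variance replaces it with the batch mean of per-trajectory estimates \(\zeta\). Averaging over a batch of \(B\) trajectories is an assumption about the reduction.

```latex
% Trajectory Balance, for one complete trajectory \tau
\mathcal{L}_{\mathrm{TB}}(\tau) = \left( \log Z + \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t) - \log R(x) - \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}) \right)^2

% Log Partition Variance, for a batch \tau_1, \dots, \tau_B
\zeta(\tau) = \log R(x) + \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}) - \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t),
\qquad
\mathcal{L}_{\mathrm{LPV}} = \frac{1}{B} \sum_{i=1}^{B} \left( \zeta(\tau_i) - \bar{\zeta} \right)^2,
\quad
\bar{\zeta} = \frac{1}{B} \sum_{i=1}^{B} \zeta(\tau_i)
```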
Module Contents
Classes
| Class | Description |
|---|---|
| LogPartitionVarianceGFlowNet | GFlowNet trained with the Log Partition Variance loss. |
| TBGFlowNet | GFlowNet trained with the Trajectory Balance loss; holds the learned logZ estimate. |
- class gfn.gflownet.trajectory_balance.LogPartitionVarianceGFlowNet(pf, pb, on_policy=False, log_reward_clip_min=-12)
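Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
GFlowNet trained with the Log Partition Variance loss. Unlike TBGFlowNet, it has no learned logZ parameter (note the absence of init_logZ in its signature); the log-partition estimate is derived from each batch of trajectories.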
- log_reward_clip_min
Minimal value to which the log reward is clamped.
- Raises
ValueError – if the loss is NaN.
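- Parameters
pf (gfn.modules.GFNModule) – forward policy estimator.
pb (gfn.modules.GFNModule) – backward policy estimator.
on_policy (bool) – if True, the log-probabilities stored in the trajectories are reused instead of being recomputed with pf.
log_reward_clip_min (float) – minimal value to which the log reward is clamped.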
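Clamping happens in log space, so a zero reward (log R = -inf) becomes log_reward_clip_min instead of producing an infinite loss. The following is a minimal sketch, assuming the clamp is applied elementwise from below to a batch of log rewards. clip_log_rewards is a hypothetical helper, not part of the library.

```python
import torch


def clip_log_rewards(log_rewards: torch.Tensor, log_reward_clip_min: float = -12.0) -> torch.Tensor:
    """Hypothetical helper: clamp log rewards from below so they stay finite."""
    # Values below the minimum (including -inf from a zero reward) are raised to it;
    # larger values pass through unchanged.
    return log_rewards.clamp(min=log_reward_clip_min)


# Rewards of 0, 1e-8 and 0.5, converted to log space.
log_r = torch.log(torch.tensor([0.0, 1e-8, 0.5]))
clipped = clip_log_rewards(log_r)
```

Here, both the zero reward and the tiny reward end up at -12, while log(0.5) is left untouched.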
- loss(env, trajectories)
Log Partition Variance loss.
This loss is described in Section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).
- Parameters
env (gfn.env.Env) – the environment the trajectories were sampled from.
trajectories (gfn.containers.Trajectories) – a batch of complete trajectories.
- Return type
torchtyping.TensorType[0, float]
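The loss can be sketched in a few lines of PyTorch. This sketch assumes that the per-step forward and backward log-probabilities of each trajectory are already available as zero-padded tensors of shape (max_len, batch). The function name and tensor layout are assumptions for illustration, not the library's API. The NaN check mirrors the ValueError documented above.

```python
import torch


def log_partition_variance_loss(
    log_pf: torch.Tensor,
    log_pb: torch.Tensor,
    log_rewards: torch.Tensor,
) -> torch.Tensor:
    """Sketch of the Log Partition Variance loss for a batch of complete trajectories.

    Assumed shapes: log_pf and log_pb are (max_len, batch), zero-padded after
    each trajectory ends; log_rewards is (batch,).
    """
    # Per-trajectory log Z estimate: zeta = log R(x) + sum log P_B - sum log P_F
    zeta = log_rewards + log_pb.sum(dim=0) - log_pf.sum(dim=0)
    # Batch variance of the estimates around their mean (no Bessel correction)
    loss = (zeta - zeta.mean()).pow(2).mean()
    if torch.isnan(loss):
        raise ValueError("loss is NaN")
    return loss


# Two trajectories of length 2; backward log-probabilities of 0 (single parent per state).
log_pf = torch.tensor([[-1.0, -0.5], [-1.0, 0.0]])
log_pb = torch.zeros(2, 2)
log_rewards = torch.tensor([0.0, 0.0])
loss = log_partition_variance_loss(log_pf, log_pb, log_rewards)  # zeta = [2.0, 0.5]
```

No learned logZ appears anywhere: the loss is zero exactly when every trajectory in the batch yields the same estimate of log Z.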
- class gfn.gflownet.trajectory_balance.TBGFlowNet(pf, pb, on_policy=False, init_logZ=0.0, log_reward_clip_min=-12)
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
Holds the logZ estimate for the Trajectory Balance loss.
The parameter space is \(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) is the set of possible values of logZ, \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG, and \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG (or a singleton subset of it, if self.logit_PB is a fixed DiscretePBEstimator).
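The product structure of the parameter space can be mirrored by a small PyTorch module: one scalar parameter for \(\mathcal{O}_1\) and one network each for \(\mathcal{O}_2\) and \(\mathcal{O}_3\). This is a hypothetical sketch. The class TBParameters and its linear layers are illustrative assumptions, not the library's estimators. Freezing the backward network stands in for a fixed PB, which collapses \(\mathcal{O}_3\) to a single function.

```python
import torch
from torch import nn


class TBParameters(nn.Module):
    """Hypothetical container mirroring O_PFZ = O_1 x O_2 x O_3."""

    def __init__(self, state_dim: int, n_actions: int, init_logZ: float = 0.0, fixed_pb: bool = False):
        super().__init__()
        # O_1: a single learnable real number for log Z
        self.logZ = nn.Parameter(torch.tensor(float(init_logZ)))
        # O_2: forward policy logits (masking to edges of the DAG is omitted here)
        self.pf = nn.Linear(state_dim, n_actions)
        # O_3: backward policy logits; frozen when PB is fixed, i.e. a single function
        self.pb = nn.Linear(state_dim, n_actions)
        if fixed_pb:
            self.pb.requires_grad_(False)

    def n_trainable(self) -> int:
        # Number of scalar parameters an optimizer would update
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


params = TBParameters(state_dim=4, n_actions=3, init_logZ=1.5)
frozen_pb = TBParameters(state_dim=4, n_actions=3, fixed_pb=True)
```

With a learned PB, all three factors are optimized jointly (1 + 15 + 15 = 31 scalars here). With fixed_pb=True, only logZ and the forward network (16 scalars) are trained.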
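- Parameters
pf (gfn.modules.GFNModule) – forward policy estimator.
pb (gfn.modules.GFNModule) – backward policy estimator.
on_policy (bool) – if True, the log-probabilities stored in the trajectories are reused instead of being recomputed with pf.
init_logZ (float) – initial value of the learned logZ estimate.
log_reward_clip_min (float) – minimal value to which the log reward is clamped.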
- logZ
A LogZEstimator instance.
- log_reward_clip_min
Minimal value to which the log reward is clamped.
- loss(env, trajectories)
Trajectory balance loss.
The trajectory balance loss is described in Section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).
- Raises
ValueError – if the loss is NaN.
- Parameters
env (gfn.env.Env) – the environment the trajectories were sampled from.
trajectories (gfn.containers.Trajectories) – a batch of complete trajectories.
- Return type
torchtyping.TensorType[0, float]
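The trajectory balance loss can be sketched for a batch of complete trajectories. This sketch assumes zero-padded per-step log-probabilities of shape (max_len, batch) and a scalar logZ parameter. The function name, tensor layout, and batch-mean reduction are assumptions for illustration, not the library's API.

```python
import torch


def trajectory_balance_loss(
    log_pf: torch.Tensor,
    log_pb: torch.Tensor,
    log_rewards: torch.Tensor,
    logZ: torch.Tensor,
) -> torch.Tensor:
    """Sketch of the trajectory balance loss averaged over a batch.

    Assumed shapes: log_pf and log_pb are (max_len, batch), zero-padded after
    each trajectory ends; log_rewards is (batch,); logZ is a scalar.
    """
    # Per-trajectory residual without logZ: sum log P_F - sum log P_B - log R(x)
    scores = log_pf.sum(dim=0) - log_pb.sum(dim=0) - log_rewards
    # Squared TB residual (log Z + scores)^2, averaged over the batch
    loss = (scores + logZ).pow(2).mean()
    if torch.isnan(loss):
        raise ValueError("loss is NaN")
    return loss


# Two trajectories of length 2; backward log-probabilities of 0 (single parent per state).
log_pf = torch.tensor([[-1.0, -0.5], [-1.0, 0.0]])
log_pb = torch.zeros(2, 2)
log_rewards = torch.tensor([0.0, 0.0])
logZ = torch.nn.Parameter(torch.tensor(0.0))

loss = trajectory_balance_loss(log_pf, log_pb, log_rewards, logZ)  # scores = [-2.0, -0.5]
loss.backward()  # the gradient flows into logZ
```

For a fixed batch, the loss is minimized over logZ at minus the mean of the scores. At that optimum, what remains is the variance of the scores, which is the Log Partition Variance objective.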