gfn.gflownet.trajectory_balance

Implementations of the [Trajectory Balance loss](https://arxiv.org/abs/2201.13259) and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446).

Module Contents

Classes

LogPartitionVarianceGFlowNet

Dataclass which holds the logZ estimate for the Log Partition Variance loss.

TBGFlowNet

Holds the logZ estimate for the Trajectory Balance loss.

class gfn.gflownet.trajectory_balance.LogPartitionVarianceGFlowNet(pf, pb, on_policy=False, log_reward_clip_min=-12)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

Dataclass which holds the logZ estimate for the Log Partition Variance loss.

log_reward_clip_min

minimum value to which the log reward is clamped.

Raises

ValueError – if the loss is NaN.

loss(env, trajectories)

Log Partition Variance loss.

This method is described in Section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).

Return type

torchtyping.TensorType[0, float]
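
As an illustration of the Log Partition Variance loss, the following is a minimal pure-Python sketch, not the library's implementation. It assumes that each trajectory is scored as the sum of forward log-probabilities minus the sum of backward log-probabilities minus the (clamped) log reward, and that the loss is the variance of these scores over the batch. The function name `log_partition_variance_loss` and its list-based arguments are hypothetical and chosen only for this example.

```python
def log_partition_variance_loss(log_pf, log_pb, log_rewards,
                                log_reward_clip_min=-12.0):
    """Variance of per-trajectory scores over a batch (illustrative sketch).

    log_pf[i] and log_pb[i] are lists of per-step forward and backward
    log-probabilities for trajectory i; log_rewards[i] is the log reward
    of its terminal state. All names here are hypothetical.
    """
    scores = []
    for pf_steps, pb_steps, log_r in zip(log_pf, log_pb, log_rewards):
        # Clamp the log reward from below, mirroring log_reward_clip_min.
        log_r = max(log_r, log_reward_clip_min)
        # When the flow is balanced, every score equals -log Z.
        scores.append(sum(pf_steps) - sum(pb_steps) - log_r)
    mean_score = sum(scores) / len(scores)
    # Mean squared deviation from the batch mean, i.e. the batch variance.
    return sum((s - mean_score) ** 2 for s in scores) / len(scores)
```

Under these assumptions no separate logZ parameter is needed: the batch mean of the scores plays the role of the (negative) log partition estimate, and the loss is zero exactly when all trajectories in the batch agree on it.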

class gfn.gflownet.trajectory_balance.TBGFlowNet(pf, pb, on_policy=False, init_logZ=0.0, log_reward_clip_min=-12)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

Holds the logZ estimate for the Trajectory Balance loss.

\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) represents the possible values for logZ, and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator.

Parameters
logZ

a LogZEstimator instance.

log_reward_clip_min

minimum value to which the log reward is clamped.
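
To make the role of log_reward_clip_min concrete, here is a small sketch of lower-bound clamping on a scalar log reward. It assumes the clamp is applied in log space, as the parameter name suggests; the helper `clamp_log_reward` is hypothetical and is not part of the library.

```python
def clamp_log_reward(log_reward, log_reward_clip_min=-12.0):
    """Return log_reward, raised to log_reward_clip_min if it is smaller."""
    # Very small rewards give very negative log rewards (-inf for a zero
    # reward); the lower bound keeps the loss finite in those cases.
    return max(log_reward, log_reward_clip_min)
```

With the default of -12, a log reward of -30 would be treated as -12, while a log reward of -1.5 would pass through unchanged.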

loss(env, trajectories)

Trajectory balance loss.

The trajectory balance loss is described in Section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).

Raises

ValueError – if the loss is NaN.

Return type

torchtyping.TensorType[0, float]
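
The trajectory balance objective can be sketched in a few lines of pure Python. This is an illustrative example under stated assumptions, not the library's code: it assumes the per-trajectory residual is logZ plus the sum of forward log-probabilities, minus the sum of backward log-probabilities, minus the (clamped) log reward, and that the loss is the mean squared residual over the batch. The function name `trajectory_balance_loss` and its list-based arguments are hypothetical.

```python
import math


def trajectory_balance_loss(log_z, log_pf, log_pb, log_rewards,
                            log_reward_clip_min=-12.0):
    """Mean squared trajectory balance residual (illustrative sketch).

    log_z is the scalar logZ estimate; log_pf[i] and log_pb[i] are lists of
    per-step forward and backward log-probabilities for trajectory i;
    log_rewards[i] is the log reward of its terminal state.
    """
    squared_residuals = []
    for pf_steps, pb_steps, log_r in zip(log_pf, log_pb, log_rewards):
        # Clamp the log reward from below, mirroring log_reward_clip_min.
        log_r = max(log_r, log_reward_clip_min)
        # Residual of Z * prod(P_F) = R(x) * prod(P_B), taken in log space.
        residual = log_z + sum(pf_steps) - sum(pb_steps) - log_r
        squared_residuals.append(residual ** 2)
    return sum(squared_residuals) / len(squared_residuals)


# Toy case: two terminal states with reward 1 each, so Z = 2, reached in
# one step with forward probability 0.5 and backward probability 1.
example_loss = trajectory_balance_loss(
    log_z=math.log(2.0),
    log_pf=[[math.log(0.5)], [math.log(0.5)]],
    log_pb=[[0.0], [0.0]],
    log_rewards=[0.0, 0.0],
)
```

In this toy case the residual vanishes for both trajectories, so `example_loss` is zero up to floating-point rounding; an incorrect logZ would shift every residual by the same amount and give a strictly positive loss.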