gfn.losses.sub_trajectory_balance

Module Contents

Classes

SubTBParametrization

Exactly the same as DBParametrization

SubTrajectoryBalance

Sub-trajectory balance (SubTB) loss.

Attributes

LogPTrajectoriesTensor

LossTensor

ScoresTensor

gfn.losses.sub_trajectory_balance.LogPTrajectoriesTensor
gfn.losses.sub_trajectory_balance.LossTensor
gfn.losses.sub_trajectory_balance.ScoresTensor
class gfn.losses.sub_trajectory_balance.SubTBParametrization

Bases: gfn.losses.base.PFBasedParametrization

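Exactly the same as DBParametrization: a log state-flow estimator, logF, in addition to the forward and backward policy estimators of PFBasedParametrization.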

logF: gfn.estimators.LogStateFlowEstimator
class gfn.losses.sub_trajectory_balance.SubTrajectoryBalance(parametrization, log_reward_clip_min=-12, weighing='geometric', lamda=0.9, on_policy=False)

Bases: gfn.losses.base.TrajectoryDecomposableLoss

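Sub-trajectory balance (SubTB) loss. A concrete subclass of TrajectoryDecomposableLoss that scores every sub-trajectory of each trajectory and combines those scores according to weighing.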

Parameters
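  • parametrization (SubTBParametrization) – the estimators (forward and backward policies and log state flows) used to compute the loss.

  • log_reward_clip_min (float) – lower bound applied to log rewards before they enter the loss (default -12).

  • weighing (Literal['DB', 'ModifiedDB', 'TB', 'geometric', 'equal', 'geometric_within', 'equal_within']) – scheme used to weigh the contributions of the different sub-trajectories (default 'geometric').

  • lamda (float) – decay factor λ of the geometric weighing schemes; a sub-trajectory of length k is weighted in proportion to λ^k (default 0.9).

  • on_policy (bool) – whether the trajectories were sampled from the current forward policy, in which case the log-probabilities stored during sampling are reused instead of being recomputed (default False).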
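The geometric weighing schemes follow the idea of the SubTB(λ) objective: longer sub-trajectories are down-weighted geometrically in lamda. The sketch below is plain Python rather than the library's tensor code. It shows one plausible reading of 'geometric_within', where weights proportional to lamda**k are normalized over the sub-trajectories of a single trajectory. The helper geometric_weights_within is hypothetical and not part of gfn.

```python
from typing import List


def geometric_weights_within(n_steps: int, lamda: float) -> List[List[float]]:
    """Hypothetical helper (not part of gfn).

    weights[k - 1][i] is the weight of the length-k sub-trajectory starting at
    state i of a trajectory with n_steps transitions. Raw weights are lamda**k,
    normalized so that all weights of the trajectory sum to 1.
    """
    # A trajectory with n_steps transitions has n_steps - k + 1 sub-trajectories of length k.
    raw = [[lamda**k] * (n_steps - k + 1) for k in range(1, n_steps + 1)]
    total = sum(sum(row) for row in raw)
    return [[w / total for w in row] for row in raw]


weights = geometric_weights_within(n_steps=3, lamda=0.9)
```

With lamda=1.0 every sub-trajectory receives the same weight, which matches the intent of 'equal_within'.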

__call__(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) – the batch of trajectories on which to evaluate the loss.

Return type

LossTensor
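__call__ reduces all sub-trajectory scores to a single LossTensor. The following minimal sketch shows one such reduction. It assumes a squared-error loss with batch-wide weights lamda**k, which is one reading of weighing='geometric'. Here scores and masks are nested lists laid out like the output of get_scores: scores[k-1][i][b] is the score of the length-k sub-trajectory starting at state i of trajectory b. The function subtb_loss is hypothetical, not the library's implementation.

```python
from typing import Sequence


def subtb_loss(
    scores: Sequence[Sequence[Sequence[float]]],
    masks: Sequence[Sequence[Sequence[bool]]],
    lamda: float,
) -> float:
    """Hypothetical sketch: weighted mean of squared scores over existing sub-trajectories.

    masks has the same layout as scores and is True where the sub-trajectory
    does not exist (it runs past the end of a shorter trajectory).
    """
    weighted_sum = 0.0
    weight_total = 0.0
    for k, (score_k, mask_k) in enumerate(zip(scores, masks), start=1):
        weight = lamda**k  # geometric down-weighting of longer sub-trajectories
        for score_row, mask_row in zip(score_k, mask_k):
            for score, missing in zip(score_row, mask_row):
                if missing:
                    continue  # skip sub-trajectories that do not exist
                weighted_sum += weight * score**2
                weight_total += weight
    return weighted_sum / weight_total if weight_total > 0 else 0.0


# One trajectory with two transitions: two length-1 scores and one length-2 score.
scores = [[[1.0], [0.0]], [[2.0]]]
masks = [[[False], [False]], [[False]]]
loss = subtb_loss(scores, masks, lamda=0.5)
```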

cumulative_logprobs(trajectories, log_p_trajectories)
Parameters
  • trajectories (gfn.containers.Trajectories) – the batch of trajectories

  • log_p_trajectories (LogPTrajectoriesTensor) – log probabilities of each transition in each trajectory

Returns

running (cumulative) sum, along the time dimension, of the per-transition log probabilities of each trajectory

Return type

LogPTrajectoriesTensor
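The following plain-Python sketch shows what a cumulative log-probability table might look like, with nested lists standing in for tensors. As illustrative assumptions, a row of zeros is prepended, and padded steps of shorter trajectories are assumed to hold 0. With the zero row, the log probability of the sub-trajectory from state i to state j is row j minus row i. The function shares the method's name but is a standalone illustration, not the library's code.

```python
from typing import List


def cumulative_logprobs(log_p: List[List[float]]) -> List[List[float]]:
    """Standalone illustration, not the gfn method.

    log_p[t][b] is the log probability of transition t in trajectory b.
    Returns max_length + 1 rows; row t is the sum of the first t transitions
    of each trajectory, and row 0 is all zeros.
    """
    n_trajectories = len(log_p[0]) if log_p else 0
    rows = [[0.0] * n_trajectories]
    for step in log_p:
        rows.append([prev + cur for prev, cur in zip(rows[-1], step)])
    return rows


# Two trajectories over three time steps (the second ends early; its padded step holds 0.0).
log_p = [[-1.0, -0.5], [-2.0, -0.5], [-0.5, 0.0]]
cum = cumulative_logprobs(log_p)
# Log probability of transitions 1..2 of trajectory 0: cum[3][0] - cum[1][0]
```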

get_scores(trajectories)

Returns two elements:

  • A list of tensors, the k-th of which holds the scores of all sub-trajectories of length k, for k in [1, …, trajectories.max_length]. The score of a sub-trajectory tau is

    log P_F(tau) + log F(tau_0) - log P_B(tau) - log F(tau_{-1}),

    where tau_0 and tau_{-1} are its first and last states. The k-th tensor (k starting from 1) has shape (trajectories.max_length - k + 1, trajectories.n_trajectories).

  • A list of tensors marking what should be masked out in each element of the first list, since not every sub-trajectory of length k exists in every trajectory. An entry is True if the corresponding sub-trajectory does not exist.

Parameters

trajectories (gfn.containers.Trajectories) – the batch of trajectories to score.

Return type

Tuple[List[ScoresTensor], List[ScoresTensor]]
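The score definition and masking rule of get_scores can be sketched in plain Python, with nested lists standing in for tensors. The inputs and the function get_scores_sketch are hypothetical. The inputs are per-step log P_F and log P_B, log F of every visited state, and each trajectory's number of transitions. The sketch applies the formula literally and does not model any special treatment of terminating states that the class may apply.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def get_scores_sketch(
    log_pf: Matrix,      # [max_length][n_trajectories] per-step log P_F
    log_pb: Matrix,      # [max_length][n_trajectories] per-step log P_B
    log_f: Matrix,       # [max_length + 1][n_trajectories] log F of each visited state
    lengths: List[int],  # number of transitions in each trajectory
) -> Tuple[List[Matrix], List[List[List[bool]]]]:
    max_length = len(log_pf)
    n = len(lengths)

    def cumsum(per_step: Matrix) -> Matrix:
        # Prepend a zero row so a sub-trajectory sum is a difference of two rows.
        rows = [[0.0] * n]
        for step in per_step:
            rows.append([a + b for a, b in zip(rows[-1], step)])
        return rows

    cum_pf, cum_pb = cumsum(log_pf), cumsum(log_pb)
    all_scores, all_masks = [], []
    for k in range(1, max_length + 1):
        scores_k, masks_k = [], []
        for i in range(max_length - k + 1):  # index of the first state tau_0
            j = i + k                          # index of the last state tau_{-1}
            scores_k.append([
                log_f[i][b] + (cum_pf[j][b] - cum_pf[i][b])
                - log_f[j][b] - (cum_pb[j][b] - cum_pb[i][b])
                for b in range(n)
            ])
            # The sub-trajectory exists only if its last state is within the trajectory.
            masks_k.append([j > lengths[b] for b in range(n)])
        all_scores.append(scores_k)
        all_masks.append(masks_k)
    return all_scores, all_masks


scores, masks = get_scores_sketch(
    log_pf=[[-1.0, -1.0], [-1.0, 0.0]],
    log_pb=[[-0.5, -0.5], [-0.5, 0.0]],
    log_f=[[2.0, 1.0], [1.0, 0.5], [0.0, 0.0]],
    lengths=[2, 1],
)
```

The k-th score list has max_length - k + 1 rows, which matches the shapes stated above.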