gfn.losses.sub_trajectory_balance
Module Contents
Classes
- SubTBParametrization – Exactly the same as DBParametrization
- SubTrajectoryBalance – Abstract Base Class for all GFN Losses
Attributes
- gfn.losses.sub_trajectory_balance.LogPTrajectoriesTensor
- gfn.losses.sub_trajectory_balance.LossTensor
- gfn.losses.sub_trajectory_balance.ScoresTensor
- class gfn.losses.sub_trajectory_balance.SubTBParametrization
Bases: gfn.losses.base.PFBasedParametrization
Exactly the same as DBParametrization.
- logF: gfn.estimators.LogStateFlowEstimator
- class gfn.losses.sub_trajectory_balance.SubTrajectoryBalance(parametrization, log_reward_clip_min=-12, weighing='geometric', lamda=0.9, on_policy=False)
Bases: gfn.losses.base.TrajectoryDecomposableLoss
Abstract Base Class for all GFN Losses.
- Parameters
parametrization (SubTBParametrization)
log_reward_clip_min (float)
weighing (Literal['DB', 'ModifiedDB', 'TB', 'geometric', 'equal', 'geometric_within', 'equal_within'])
lamda (float)
on_policy (bool)
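To illustrate what a geometric weighing over sub-trajectory lengths can look like, here is a minimal, framework-free sketch. It assumes (this is not taken from the page above) that each length k = 1..max_length gets a weight proportional to lamda ** (k - 1), normalized to sum to 1. The function name geometric_length_weights is hypothetical and is not part of gfn.

```python
from typing import List


def geometric_length_weights(lamda: float, max_length: int) -> List[float]:
    """Return one weight per sub-trajectory length k = 1..max_length.

    Weights are proportional to lamda ** (k - 1) and normalized to sum to 1.
    Assumption: this only illustrates the idea behind weighing="geometric";
    it is not the library's implementation.
    """
    if lamda <= 0:
        raise ValueError("lamda must be positive")
    if max_length < 1:
        raise ValueError("max_length must be at least 1")
    raw = [lamda ** (k - 1) for k in range(1, max_length + 1)]
    total = sum(raw)
    return [w / total for w in raw]


# In this sketch, lamda < 1 favours short sub-trajectories and lamda > 1 favours long ones.
print(geometric_length_weights(lamda=0.9, max_length=4))
```

In this sketch, a small lamda concentrates almost all weight on length-1 sub-trajectories (single transitions), while lamda = 1 weighs every length equally.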
- __call__(trajectories)
- Parameters
trajectories (gfn.containers.Trajectories)
- Return type
LossTensor
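One plausible way to reduce per-length sub-trajectory scores and their masks to a single scalar loss is sketched below: average the squared scores of the existing sub-trajectories of each length, then take a weighted sum over lengths. This reduction, the helper name weighted_subtb_loss, and the use of plain Python lists instead of tensors are all assumptions for illustration, not a description of what __call__ does internally.

```python
from typing import List, Sequence

Matrix = List[List[float]]
Mask = List[List[bool]]


def weighted_subtb_loss(
    scores: Sequence[Matrix], masks: Sequence[Mask], weights: Sequence[float]
) -> float:
    """Reduce per-length scores to one scalar (hypothetical helper).

    scores[k - 1][i][j] is the score of the length-k sub-trajectory that starts at
    step i of trajectory j; masks[k - 1][i][j] is True when it does not exist.
    """
    loss = 0.0
    for score_k, mask_k, weight_k in zip(scores, masks, weights):
        # Squared scores of the sub-trajectories of length k that actually exist.
        squared = [
            s * s
            for s_row, m_row in zip(score_k, mask_k)
            for s, m in zip(s_row, m_row)
            if not m
        ]
        if squared:  # no trajectory in the batch may be long enough for this k
            loss += weight_k * sum(squared) / len(squared)
    return loss


# Two trajectories with 2 and 1 transitions, so max_length = 2.
scores = [[[0.1, 0.2], [0.3, 0.0]], [[0.5, 0.0]]]
masks = [[[False, False], [False, True]], [[False, True]]]
print(weighted_subtb_loss(scores, masks, weights=[0.5, 0.5]))
```

Averaging within each length before weighting keeps a length with many sub-trajectories (short ones) from dominating purely through its count; whether the library normalizes the same way is not stated here.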
- cumulative_logprobs(trajectories, log_p_trajectories)
- Parameters
trajectories (gfn.containers.Trajectories) – the trajectories whose log probabilities are accumulated
log_p_trajectories (LogPTrajectoriesTensor) – log probabilities of each transition in each trajectory
- Returns
cumulative sum of the per-transition log probabilities along each trajectory
- Return type
LogPTrajectoriesTensor
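The following framework-free sketch shows what a cumulative sum of per-transition log probabilities provides. It assumes a time-major layout (one row per step, one column per trajectory), padded steps set to 0.0, and an extra leading row of zeros; these layout details are assumptions, not taken from the signature above. With this prefix-sum form, the log probability of any sub-trajectory is a single subtraction.

```python
from typing import List

Matrix = List[List[float]]


def cumulative_logprobs(log_p: Matrix) -> Matrix:
    """Prefix sums of per-transition log-probabilities along the time axis.

    Assumed layout: log_p[t][j] is the log-probability of transition t of
    trajectory j (time-major), with padded steps set to 0.0. The result has one
    extra leading row of zeros, so row t holds the sum of the first t transitions.
    """
    n_trajectories = len(log_p[0]) if log_p else 0
    cumulative = [[0.0] * n_trajectories]
    for row in log_p:
        cumulative.append([acc + x for acc, x in zip(cumulative[-1], row)])
    return cumulative


log_p = [[-1.0, -0.5], [-2.0, 0.0]]  # trajectory 1 has a single real transition
cum = cumulative_logprobs(log_p)
# Log-probability of the sub-trajectory made of transitions i..j-1 of trajectory n:
i, j, n = 1, 2, 0
print(cum[j][n] - cum[i][n])  # prints -2.0
```

Precomputing prefix sums once means scoring all O(max_length^2) sub-trajectories needs no repeated summation over their transitions.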
- get_scores(trajectories)
Returns two elements:
- A list of tensors, each representing the scores of all sub-trajectories of length k, for k in [1, …, trajectories.max_length]. The score of a sub-trajectory tau is log P_F(tau) + log F(tau_0) - log P_B(tau) - log F(tau_{-1}), where tau_0 and tau_{-1} are its first and last states. The k-th tensor (k starting from 1) has shape (trajectories.max_length - k + 1, trajectories.n_trajectories).
- A list of tensors indicating what should be masked out in each element of the first list, given that, for each trajectory, not all sub-trajectories of length k exist. An entry is True if the corresponding sub-trajectory does not exist.
- Parameters
trajectories (gfn.containers.Trajectories)
- Return type
Tuple[List[ScoresTensor], List[ScoresTensor]]
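The score formula and mask shapes described for get_scores can be reproduced in a small, framework-free sketch. It assumes hypothetical inputs: prefix-summed forward and backward log probabilities (time-major, with a leading zero row), a per-step log-flow table log_flows, and the number of transitions in each trajectory. The function subtrajectory_scores is illustrative only and is not the library's implementation.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def subtrajectory_scores(
    cum_log_pf: Matrix,  # (max_length + 1) x n prefix sums of log P_F, leading zero row
    cum_log_pb: Matrix,  # same layout for log P_B
    log_flows: Matrix,   # (max_length + 1) x n, hypothetical log F of the state at each step
    lengths: List[int],  # number of transitions in each trajectory
) -> Tuple[List[Matrix], List[List[List[bool]]]]:
    max_length = len(cum_log_pf) - 1
    n = len(lengths)
    all_scores, all_masks = [], []
    for k in range(1, max_length + 1):
        scores_k, mask_k = [], []
        for i in range(max_length - k + 1):  # start step of the sub-trajectory
            row_s, row_m = [], []
            for j in range(n):
                end = i + k  # index of the last state of the sub-trajectory
                log_pf = cum_log_pf[end][j] - cum_log_pf[i][j]
                log_pb = cum_log_pb[end][j] - cum_log_pb[i][j]
                # log P_F(tau) + log F(tau_0) - log P_B(tau) - log F(tau_{-1})
                row_s.append(log_pf + log_flows[i][j] - log_pb - log_flows[end][j])
                row_m.append(end > lengths[j])  # True: sub-trajectory does not exist
            scores_k.append(row_s)
            mask_k.append(row_m)
        all_scores.append(scores_k)
        all_masks.append(mask_k)
    return all_scores, all_masks


cum_log_pf = [[0.0, 0.0], [-1.0, -0.2], [-3.0, -0.2]]
cum_log_pb = [[0.0, 0.0], [-0.5, -0.1], [-1.0, -0.1]]
log_flows = [[2.0, 0.5], [1.0, 0.1], [0.0, 0.0]]
scores, masks = subtrajectory_scores(cum_log_pf, cum_log_pb, log_flows, lengths=[2, 1])
print(masks)
```

The k-th score list has max_length - k + 1 rows, matching the shapes stated above. This sketch uses log F for the last state of every sub-trajectory and does not model how sub-trajectories ending in a terminal state are scored (in sub-trajectory balance, the flow of a terminal state is tied to its reward).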