gfn.gflownet.sub_trajectory_balance
Module Contents
Classes
SubTBGFlowNet: GFlowNet for the Sub Trajectory Balance Loss.
Attributes
- gfn.gflownet.sub_trajectory_balance.ContributionsTensor
- gfn.gflownet.sub_trajectory_balance.CumulativeLogProbsTensor
- gfn.gflownet.sub_trajectory_balance.LogStateFlowsTensor
- gfn.gflownet.sub_trajectory_balance.LogTrajectoriesTensor
- gfn.gflownet.sub_trajectory_balance.MaskTensor
- gfn.gflownet.sub_trajectory_balance.PredictionsTensor
- class gfn.gflownet.sub_trajectory_balance.SubTBGFlowNet(pf, pb, logF, on_policy=False, weighting='geometric_within', lamda=0.9, log_reward_clip_min=-12, forward_looking=False)
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
GFlowNet for the Sub Trajectory Balance Loss.
This method is described in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782).
- Parameters
pf (gfn.modules.GFNModule) –
pb (gfn.modules.GFNModule) –
logF (gfn.modules.ScalarEstimator) –
on_policy (bool) –
weighting (Literal[DB, ModifiedDB, TB, geometric, equal, geometric_within, equal_within]) –
lamda (float) –
log_reward_clip_min (float) –
forward_looking (bool) –
- logF
a LogStateFlowEstimator instance.
- weighting
sub-trajectories weighting scheme.
- “DB”: Considers all one-step transitions of each trajectory in the batch and weighs them equally (regardless of the length of trajectory). Should be equivalent to DetailedBalance loss.
- “ModifiedDB”: Considers all one-step transitions of each trajectory in the batch and weighs them inversely proportional to the trajectory length. This ensures that the loss is not dominated by long trajectories. Each trajectory contributes equally to the loss.
- “TB”: Considers only the full trajectory. Should be equivalent to TrajectoryBalance loss.
- “equal_within”: Each sub-trajectory of each trajectory is weighed equally within the trajectory. Then each trajectory is weighed equally within the batch.
- “equal”: Each sub-trajectory of each trajectory is weighed equally within the set of all sub-trajectories.
- “geometric_within”: Each sub-trajectory of each trajectory is weighed proportionally to (lamda ** len(sub_trajectory)), within each trajectory. This corresponds to the one in the paper.
- “geometric”: Each sub-trajectory of each trajectory is weighed proportionally to (lamda ** len(sub_trajectory)), within the set of all sub-trajectories.
- lamda
discount factor for longer trajectories.
- log_reward_clip_min
minimum value for log rewards.
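To make the “geometric_within” scheme concrete, here is a minimal pure-Python sketch for a single trajectory. It is not the library's implementation: the helper name `geometric_within_weights` is hypothetical, and normalizing the weights so they sum to 1 within the trajectory is an assumption based on the description above.

```python
def geometric_within_weights(n_steps, lamda=0.9):
    """Hypothetical helper: weight of each sub-trajectory of one trajectory.

    A sub-trajectory is identified by (start, end) state indices with
    0 <= start < end <= n_steps; its length is end - start transitions.
    """
    # Raw weight is lamda ** length, as in the 'geometric_within' description.
    raw = {
        (start, end): lamda ** (end - start)
        for start in range(n_steps)
        for end in range(start + 1, n_steps + 1)
    }
    # Assumption: weights are normalized to sum to 1 within the trajectory.
    total = sum(raw.values())
    return {span: w / total for span, w in raw.items()}


weights = geometric_within_weights(n_steps=3, lamda=0.5)
for span, w in sorted(weights.items()):
    print(span, round(w, 4))
```

With lamda < 1, shorter sub-trajectories receive more weight than longer ones; with lamda = 1 the sketch reduces to equal weights within the trajectory.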
- calculate_log_state_flows(env, trajectories, log_pf_trajectories)
Calculate log state flows and masks for sink and terminal states.
- Parameters
trajectories (gfn.containers.Trajectories) – The trajectories data.
env (gfn.env.Env) – The environment object.
log_pf_trajectories (LogTrajectoriesTensor) –
- Returns
log_state_flows: Log state flows. full_mask: A boolean tensor representing full states.
- calculate_masks(log_state_flows, trajectories)
Calculate masks for sink and terminal states.
- Parameters
log_state_flows (LogStateFlowsTensor) –
trajectories (gfn.containers.Trajectories) –
- Return type
Tuple[MaskTensor, MaskTensor, MaskTensor]
- calculate_preds(log_pf_trajectories_cum, log_state_flows, i)
Calculate the predictions tensor for the current sub-trajectory length.
- Parameters
log_pf_trajectories_cum (CumulativeLogProbsTensor) –
log_state_flows (LogStateFlowsTensor) –
i (int) –
- Return type
PredictionsTensor
- calculate_targets(trajectories, preds, log_pb_trajectories_cum, log_state_flows, is_terminal_mask, sink_states_mask, full_mask, i)
Calculate the targets tensor for the current sub-trajectory length.
- Parameters
trajectories (gfn.containers.Trajectories) –
preds (PredictionsTensor) –
log_pb_trajectories_cum (CumulativeLogProbsTensor) –
log_state_flows (LogStateFlowsTensor) –
is_terminal_mask (MaskTensor) –
sink_states_mask (MaskTensor) –
full_mask (MaskTensor) –
i (int) –
- Return type
TargetsTensor
- cumulative_logprobs(trajectories, log_p_trajectories)
Calculates the cumulative log probabilities for all trajectories.
- Parameters
trajectories (gfn.containers.Trajectories) – a batch of trajectories.
log_p_trajectories (LogTrajectoriesTensor) – log probabilities of each transition in each trajectory.
- Returns
Cumulative sum of log probabilities of each trajectory.
- Return type
CumulativeLogProbsTensor
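Cumulative sums are useful here because the log probability of any contiguous sub-trajectory becomes a difference of two entries. The sketch below illustrates the idea on a single trajectory with plain Python lists. The helper name `cumulative_logprobs_1d` and the leading-zero layout are assumptions for illustration, not the library's tensor layout.

```python
from itertools import accumulate


def cumulative_logprobs_1d(log_p):
    """Hypothetical helper: prefix sums of per-transition log probabilities.

    A leading 0.0 is prepended (an assumption of this sketch) so that the
    log probability of transitions i..j-1 is cum[j] - cum[i].
    """
    return [0.0] + list(accumulate(log_p))


log_p = [-1.0, -2.0, -0.5]          # made-up per-transition log probabilities
cum = cumulative_logprobs_1d(log_p)
# Log probability of the sub-trajectory covering the last two transitions.
print(cum[3] - cum[1])
```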
- get_equal_contributions(trajectories)
Calculates contributions for the ‘equal’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
ContributionsTensor
- get_equal_within_contributions(trajectories)
Calculates contributions for the ‘equal_within’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
ContributionsTensor
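The difference between the ‘equal’ and ‘equal_within’ schemes can be seen by counting sub-trajectories. The following sketch is illustrative only: the helper names are hypothetical, trajectories are represented just by their lengths, and normalizing the batch total to 1 is an assumption.

```python
def n_subtrajectories(n_steps):
    # A trajectory with n_steps transitions has n_steps * (n_steps + 1) / 2
    # contiguous sub-trajectories.
    return n_steps * (n_steps + 1) // 2


def equal_weights(lengths):
    """Per-sub-trajectory weight for each trajectory under 'equal'."""
    total = sum(n_subtrajectories(n) for n in lengths)
    return [1.0 / total for _ in lengths]


def equal_within_weights(lengths):
    """Per-sub-trajectory weight for each trajectory under 'equal_within'."""
    # Each trajectory gets mass 1 / len(lengths), split evenly over its
    # own sub-trajectories.
    return [1.0 / (len(lengths) * n_subtrajectories(n)) for n in lengths]


lengths = [1, 3]                      # made-up batch of two trajectories
print(equal_weights(lengths))         # same weight for every sub-trajectory
print(equal_within_weights(lengths))  # long trajectory's pieces weigh less
```

Under ‘equal’, long trajectories dominate because they contain more sub-trajectories; under ‘equal_within’, every trajectory carries the same total mass.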
- get_geometric_within_contributions(trajectories)
Calculates contributions for the ‘geometric_within’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
ContributionsTensor
- get_modified_db_contributions(trajectories)
Calculates contributions for the ‘ModifiedDB’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
ContributionsTensor
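As a rough illustration of the ‘ModifiedDB’ scheme, the sketch below assigns each one-step transition a weight inversely proportional to the length of its trajectory. The helper name `modified_db_weights` is hypothetical and the normalization to a batch total of 1 is an assumption.

```python
def modified_db_weights(lengths):
    """Hypothetical helper: per-transition weight for each trajectory.

    Each trajectory receives total mass 1 / n_trajectories, spread evenly
    over its one-step transitions.
    """
    n_traj = len(lengths)
    return [1.0 / (n_traj * n) for n in lengths]


lengths = [2, 4]                      # made-up trajectory lengths
weights = modified_db_weights(lengths)
# Total mass per trajectory is identical, regardless of length.
print([n * w for n, w in zip(lengths, weights)])
```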
- get_scores(env, trajectories)
Scores all submitted trajectories.
- Returns
- A list of tensors, each representing the scores of all sub-trajectories of length k, for k in [1, …, trajectories.max_length], where the score of a sub-trajectory \(\tau\) is \(\log P_F(\tau) + \log F(\tau_0) - \log P_B(\tau) - \log F(\tau_{-1})\). The shape of the k-th tensor is (trajectories.max_length - k + 1, trajectories.n_trajectories), k starting from 1.
- A list of tensors representing what should be masked out in each element of the first list, given that not all sub-trajectories of length k exist for each trajectory. The entries of those tensors are True if the corresponding sub-trajectory does not exist.
- Parameters
env (gfn.env.Env) –
trajectories (gfn.containers.Trajectories) –
- Return type
Tuple[List[torchtyping.TensorType[0, float]], List[torchtyping.TensorType[0, float]]]
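The score formula above can be illustrated on a single trajectory with plain Python lists. This is a sketch under stated assumptions, not the library's tensor code: the helper name `subtrajectory_scores` is hypothetical, `log_f` holds one log flow per visited state, and masking of non-existent sub-trajectories is omitted because only one trajectory is considered.

```python
from itertools import accumulate


def subtrajectory_scores(log_pf, log_pb, log_f):
    """Hypothetical helper: scores of all sub-trajectories, grouped by length.

    log_pf, log_pb: per-transition forward / backward log probabilities.
    log_f: log state flows, one per state (len(log_pf) + 1 entries).
    Returns {k: [score of each sub-trajectory of length k]}.
    """
    n = len(log_pf)
    assert len(log_pb) == n and len(log_f) == n + 1
    cum_pf = [0.0] + list(accumulate(log_pf))
    cum_pb = [0.0] + list(accumulate(log_pb))
    scores = {}
    for k in range(1, n + 1):
        # There are n - k + 1 sub-trajectories of length k.
        scores[k] = [
            (cum_pf[i + k] - cum_pf[i]) + log_f[i]
            - (cum_pb[i + k] - cum_pb[i]) - log_f[i + k]
            for i in range(n - k + 1)
        ]
    return scores


# Made-up, flow-consistent values: every score is zero.
print(subtrajectory_scores([-1.0, -2.0], [0.0, 0.0], [0.0, -1.0, -3.0]))
```

The number of entries for each k mirrors the first dimension of the documented shape, max_length - k + 1.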
- get_tb_contributions(trajectories, all_scores)
Calculates contributions for the ‘TB’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
all_scores (torchtyping.TensorType) –
- Return type
ContributionsTensor
- loss(env, trajectories)
- Parameters
env (gfn.env.Env) –
trajectories (gfn.containers.Trajectories) –
- Return type
torchtyping.TensorType[0, float]
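The loss method is not described here. In the referenced paper, the sub-trajectory balance objective is a weighted sum of squared sub-trajectory scores. The sketch below shows that idea for one trajectory with geometric weights; the helper name `subtb_loss_single` and the input layout (scores grouped by sub-trajectory length) are assumptions, and the library's loss may differ in details such as clipping and batching.

```python
def subtb_loss_single(scores, lamda=0.9):
    """Hypothetical helper: weighted mean of squared sub-trajectory scores.

    scores: {k: [score of each sub-trajectory of length k]}.
    Each sub-trajectory of length k gets raw weight lamda ** k, and the
    weights are normalized within the trajectory (an assumption).
    """
    weighted_sum = 0.0
    weight_total = 0.0
    for k, values in scores.items():
        w = lamda ** k
        for s in values:
            weighted_sum += w * s * s
            weight_total += w
    return weighted_sum / weight_total


# Made-up scores for a trajectory with two transitions.
print(subtb_loss_single({1: [-0.5, 0.5], 2: [0.0]}, lamda=0.5))
```

The loss is zero exactly when every sub-trajectory score is zero, which is the balance condition the scores measure.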
- gfn.gflownet.sub_trajectory_balance.TargetsTensor