gfn.gflownet
Submodules
Package Contents
Classes
DBGFlowNet | The Detailed Balance GFlowNet.
FMGFlowNet | Flow Matching GFlowNet, with edge flow estimator.
GFlowNet | Abstract Base Class for GFlowNets.
LogPartitionVarianceGFlowNet | Dataclass which holds the logZ estimate for the Log Partition Variance loss.
ModifiedDBGFlowNet | The Modified Detailed Balance GFlowNet. Only applicable to environments where all states are terminating.
PFBasedGFlowNet | Base class for gflownets that explicitly use \(P_F\).
SubTBGFlowNet | GFlowNet for the Sub Trajectory Balance Loss.
TBGFlowNet | Holds the logZ estimate for the Trajectory Balance loss.
TrajectoryBasedGFlowNet | Base class for gflownets that explicitly use \(P_F\).
- class gfn.gflownet.DBGFlowNet(pf, pb, logF, on_policy=False, forward_looking=False)
Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]
The Detailed Balance GFlowNet.
Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (no \(s_f\)) to \(\mathbb{R}^+\) (which we parametrize with logs, to avoid the non-negativity constraint), and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator.
- Parameters
pf (gfn.modules.GFNModule) –
pb (gfn.modules.GFNModule) –
logF (gfn.modules.ScalarEstimator) –
on_policy (bool) –
forward_looking (bool) –
- logF
a ScalarEstimator instance.
- on_policy
boolean indicating whether we need to reevaluate the log probs.
- forward_looking
whether to implement the forward looking GFN loss.
- get_scores(env, transitions)
Given a batch of transitions, calculate the scores.
- Parameters
transitions (gfn.containers.Transitions) – a batch of transitions.
env (gfn.env.Env) –
- Raises
ValueError – when supplied with backward transitions.
AssertionError – when log rewards of transitions are None.
- Return type
Tuple[torchtyping.TensorType[n_transitions, float], torchtyping.TensorType[n_transitions, float], torchtyping.TensorType[n_transitions, float]]
- loss(env, transitions)
Detailed balance loss.
The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
- Parameters
env (gfn.env.Env) –
transitions (gfn.containers.Transitions) –
- Return type
torchtyping.TensorType[0, float]
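As a rough illustration of the quantity this loss penalizes, the sketch below computes a detailed balance residual for single transitions from plain floats. It assumes the standard form \(\log F(s) + \log P_F(s' \mid s) - \log F(s') - \log P_B(s \mid s')\) and a mean-squared reduction. The helper names `db_residual` and `db_loss` are hypothetical, and this is not the library's implementation.

```python
from typing import Sequence


def db_residual(log_f_s: float, log_pf: float, log_f_next: float, log_pb: float) -> float:
    """Detailed balance residual for one transition s -> s' (all inputs in log space)."""
    return (log_f_s + log_pf) - (log_f_next + log_pb)


def db_loss(residuals: Sequence[float]) -> float:
    """Mean squared residual over a batch of transitions (assumed reduction)."""
    return sum(r * r for r in residuals) / len(residuals)


# A transition that satisfies detailed balance: F(s) P_F(s'|s) = F(s') P_B(s|s').
balanced = db_residual(log_f_s=0.0, log_pf=-0.5, log_f_next=-0.5, log_pb=0.0)
# A transition whose forward side carries one extra nat of log-flow.
unbalanced = db_residual(log_f_s=1.0, log_pf=-0.5, log_f_next=-0.5, log_pb=0.0)
example_loss = db_loss([balanced, unbalanced])
```

The residual is zero exactly when the forward and backward flows through the edge agree, so the loss only penalizes transitions that break the balance condition.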
- to_training_samples(trajectories)
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
- class gfn.gflownet.FMGFlowNet(logF, alpha=1.0)
Bases: gfn.gflownet.base.GFlowNet[Tuple[gfn.states.DiscreteStates,gfn.states.DiscreteStates]]
Flow Matching GFlowNet, with edge flow estimator.
\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). This is equivalent to the set of functions from the internal nodes (i.e. without \(s_f\)) to \(\mathbb{R}^{n_{\text{actions}}}\), excluding the exit action. Positivity is not needed when log-flows are parametrized.
The loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
- Parameters
logF (gfn.modules.DiscretePolicyEstimator) –
alpha (float) –
- logF
LogEdgeFlowEstimator
- alpha
weight for the reward matching loss.
- flow_matching_loss(env, states)
Computes the Flow Matching (FM) loss for the provided states.
The Flow Matching loss is defined as the log-sum of incoming flows minus the log-sum of outgoing flows. The states should not include \(s_0\). The batch shape should be (n_states,). As of now, only discrete environments are handled.
- Raises
AssertionError – If the batch shape is not linear.
AssertionError – If any state is at \(s_0\).
- Parameters
env (gfn.env.Env) –
states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[n_trajectories, torch.float]
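The per-state quantity described above can be sketched in plain Python. The example assumes each state comes with a list of log-flows on its incoming edges and a list on its outgoing edges. The helpers `logsumexp` and `fm_residual` are hypothetical names used for illustration only, not functions of the library.

```python
import math
from typing import Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def fm_residual(log_in_flows: Sequence[float], log_out_flows: Sequence[float]) -> float:
    """Log-sum of incoming edge flows minus log-sum of outgoing edge flows for one state."""
    return logsumexp(log_in_flows) - logsumexp(log_out_flows)


# Incoming flows 1 and 3, outgoing flows 2 and 2: flow is conserved at this state.
conserved = fm_residual([math.log(1.0), math.log(3.0)], [math.log(2.0), math.log(2.0)])
# Same incoming flows, but only 2 units of flow leave the state.
leaky = fm_residual([math.log(1.0), math.log(3.0)], [math.log(2.0)])
```

Working with log-flows keeps every flow positive without an explicit constraint, which is why the estimator can output unconstrained real values.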
- loss(env, states_tuple)
Given a batch of non-terminal and terminal states, compute a loss.
Unlike the GFlowNet Foundations paper, we allow more flexibility by passing a tuple of states: the first element holds the internal (i.e. non-terminal) states of the trajectories, and the second holds their terminal states.
- Parameters
env (gfn.env.Env) –
states_tuple (Tuple[gfn.states.DiscreteStates, gfn.states.DiscreteStates]) –
- Return type
torchtyping.TensorType[0, float]
- reward_matching_loss(env, terminating_states)
Calculates the reward matching loss from the terminating states.
- Parameters
env (gfn.env.Env) –
terminating_states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[0, float]
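One plausible reading of how alpha enters the objective is sketched below. It assumes the reward matching term is the mean squared gap between the log terminating flow and the log reward, and that the total loss is the flow matching term plus alpha times the reward matching term. Both forms and both helper names are assumptions made for illustration, not confirmed by this page.

```python
from typing import Sequence


def reward_matching_loss(log_exit_flows: Sequence[float], log_rewards: Sequence[float]) -> float:
    """Mean squared gap between log terminating flow and log reward (assumed form)."""
    gaps = [f - r for f, r in zip(log_exit_flows, log_rewards)]
    return sum(g * g for g in gaps) / len(gaps)


def total_fm_loss(fm_loss: float, rm_loss: float, alpha: float = 1.0) -> float:
    """Combine both terms, weighting the reward matching term by alpha (assumed form)."""
    return fm_loss + alpha * rm_loss


# Two terminating states: the first matches its reward, the second is off by 2 nats.
rm = reward_matching_loss([0.0, 1.0], [0.0, 3.0])
total = total_fm_loss(fm_loss=0.5, rm_loss=rm, alpha=0.1)
```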
- sample_trajectories(env, n_samples=1000)
Sample a specific number of complete trajectories.
- Parameters
env (gfn.env.Env) – the environment to sample trajectories from.
n_samples (int) – number of trajectories to be sampled.
- Returns
sampled trajectories object.
- Return type
- to_training_samples(trajectories)
Converts a batch of trajectories into a batch of training samples.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
- class gfn.gflownet.GFlowNet
Bases: abc.ABC, torch.nn.Module, Generic[TrainingSampleType]
Abstract Base Class for GFlowNets.
A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266).
- abstract loss(env, training_objects)
Computes the loss given the training objects.
- Parameters
env (gfn.env.Env) –
- sample_terminating_states(env, n_samples)
Rolls out the parametrization’s policy and returns the terminating states.
- Parameters
env (gfn.env.Env) – the environment to sample terminating states from.
n_samples (int) – number of terminating states to be sampled.
- Returns
sampled terminating states object.
- Return type
- abstract sample_trajectories(env, n_samples)
Sample a specific number of complete trajectories.
- Parameters
env (gfn.env.Env) – the environment to sample trajectories from.
n_samples (int) – number of trajectories to be sampled.
- Returns
sampled trajectories object.
- Return type
- abstract to_training_samples(trajectories)
Converts trajectories to training samples. The type depends on the GFlowNet.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
TrainingSampleType
- class gfn.gflownet.LogPartitionVarianceGFlowNet(pf, pb, on_policy=False, log_reward_clip_min=-12)
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
Dataclass which holds the logZ estimate for the Log Partition Variance loss.
- log_reward_clip_min
minimum value to which the log reward is clamped.
- Raises
ValueError – if the loss is NaN.
- Parameters
pf (gfn.modules.GFNModule) –
pb (gfn.modules.GFNModule) –
on_policy (bool) –
log_reward_clip_min (float) –
- loss(env, trajectories)
Log Partition Variance loss.
This method is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).
- Parameters
env (gfn.env.Env) –
trajectories (gfn.containers.Trajectories) –
- Return type
torchtyping.TensorType[0, float]
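A minimal sketch of the idea, assuming each trajectory implies an estimate \(\log Z = \log R + \log P_B - \log P_F\) and the loss is the variance of those estimates over the batch. The function name and the list-of-floats inputs are illustrative assumptions. The library operates on batched tensors.

```python
from typing import Sequence


def log_partition_variance_loss(
    log_pf: Sequence[float],
    log_pb: Sequence[float],
    log_rewards: Sequence[float],
    log_reward_clip_min: float = -12.0,
) -> float:
    """Variance of per-trajectory logZ estimates, so no learned logZ is needed."""
    # Clamp log rewards from below, mirroring the log_reward_clip_min argument.
    clipped = [max(r, log_reward_clip_min) for r in log_rewards]
    # Each trajectory implies an estimate logZ = log R + log P_B - log P_F.
    log_z = [r + b - f for f, b, r in zip(log_pf, log_pb, clipped)]
    mean = sum(log_z) / len(log_z)
    return sum((z - mean) ** 2 for z in log_z) / len(log_z)


# Both trajectories imply logZ = 1.0, so the loss vanishes.
agree = log_partition_variance_loss([-1.0, -2.0], [0.0, -0.5], [0.0, -0.5])
# Estimates of 1.0 and 3.0 give a variance of 1.0.
disagree = log_partition_variance_loss([-1.0, -3.0], [0.0, 0.0], [0.0, 0.0])
```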
- class gfn.gflownet.ModifiedDBGFlowNet
Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]
The Modified Detailed Balance GFlowNet. Only applicable to environments where all states are terminating.
See [Bayesian Structure Learning with Generative Flow Networks](https://arxiv.org/abs/2202.13903) for more details.
- get_scores(transitions)
DAG-GFN-style detailed balance, when all states are connected to the sink.
- Raises
ValueError – when backward transitions are supplied (not supported).
ValueError – when the computed scores contain inf.
- Parameters
transitions (gfn.containers.Transitions) –
- Return type
torchtyping.TensorType[n_trajectories, torch.float]
- loss(env, transitions)
Calculates the modified detailed balance loss.
- Parameters
env (gfn.env.Env) –
transitions (gfn.containers.Transitions) –
- Return type
torchtyping.TensorType[0, float]
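When every state is terminating, the state flow can be written as \(F(s) = R(s) / P_F(s_f \mid s)\), which removes the need for a separate flow estimator. The sketch below substitutes that identity into the detailed balance condition. It follows the form given in the paper cited for this class, the function name is hypothetical, and it is not the library's code.

```python
import math


def modified_db_residual(
    log_r_s: float,
    log_r_next: float,
    log_pf_step: float,
    log_pb_step: float,
    log_pf_exit_s: float,
    log_pf_exit_next: float,
) -> float:
    """Residual of R(s) P_F(s'|s) P_F(s_f|s') = R(s') P_B(s|s') P_F(s_f|s) in log space."""
    forward = log_r_s + log_pf_step + log_pf_exit_next
    backward = log_r_next + log_pb_step + log_pf_exit_s
    return forward - backward


# R(s)=1, R(s')=2, P_F(s'|s)=0.5, P_B(s|s')=0.5, P_F(s_f|s)=0.5, P_F(s_f|s')=1.
residual = modified_db_residual(
    log_r_s=math.log(1.0),
    log_r_next=math.log(2.0),
    log_pf_step=math.log(0.5),
    log_pb_step=math.log(0.5),
    log_pf_exit_s=math.log(0.5),
    log_pf_exit_next=math.log(1.0),
)
```

In this example both sides equal 0.5, so the residual is zero up to floating-point error.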
- to_training_samples(trajectories)
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
- class gfn.gflownet.PFBasedGFlowNet(pf, pb, on_policy=False)
Bases: GFlowNet
Base class for gflownets that explicitly use \(P_F\).
- Parameters
pf (gfn.modules.GFNModule) –
pb (gfn.modules.GFNModule) –
on_policy (bool) –
- pf
GFNModule
- pb
GFNModule
- sample_trajectories(env, n_samples)
Sample a specific number of complete trajectories.
- Parameters
env (gfn.env.Env) – the environment to sample trajectories from.
n_samples (int) – number of trajectories to be sampled.
- Returns
sampled trajectories object.
- Return type
- class gfn.gflownet.SubTBGFlowNet(pf, pb, logF, on_policy=False, weighting='geometric_within', lamda=0.9, log_reward_clip_min=-12, forward_looking=False)
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
GFlowNet for the Sub Trajectory Balance Loss.
This method is described in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782).
- Parameters
pf (gfn.modules.GFNModule) –
pb (gfn.modules.GFNModule) –
logF (gfn.modules.ScalarEstimator) –
on_policy (bool) –
weighting (Literal[DB, ModifiedDB, TB, geometric, equal, geometric_within, equal_within]) –
lamda (float) –
log_reward_clip_min (float) –
forward_looking (bool) –
- logF
a LogStateFlowEstimator instance.
- weighting
sub-trajectories weighting scheme.
  - “DB”: Considers all one-step transitions of each trajectory in the batch and weighs them equally (regardless of the length of trajectory). Should be equivalent to DetailedBalance loss.
  - “ModifiedDB”: Considers all one-step transitions of each trajectory in the batch and weighs them inversely proportional to the trajectory length. This ensures that the loss is not dominated by long trajectories. Each trajectory contributes equally to the loss.
  - “TB”: Considers only the full trajectory. Should be equivalent to TrajectoryBalance loss.
  - “equal_within”: Each sub-trajectory of each trajectory is weighed equally within the trajectory. Then each trajectory is weighed equally within the batch.
  - “equal”: Each sub-trajectory of each trajectory is weighed equally within the set of all sub-trajectories.
  - “geometric_within”: Each sub-trajectory of each trajectory is weighed proportionally to (lamda ** len(sub_trajectory)), within each trajectory. This corresponds to the scheme in the paper.
  - “geometric”: Each sub-trajectory of each trajectory is weighed proportionally to (lamda ** len(sub_trajectory)), within the set of all sub-trajectories.
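To make the “geometric_within” scheme concrete, the sketch below enumerates every sub-trajectory of a single trajectory and normalizes the weights so they sum to 1 within that trajectory. It assumes len(sub_trajectory) counts transitions. The function name and the dictionary output are illustrative choices, not the library's data layout.

```python
from typing import Dict, Tuple


def geometric_within_weights(n_steps: int, lamda: float = 0.9) -> Dict[Tuple[int, int], float]:
    """Weights for every sub-trajectory (i, j), 0 <= i < j <= n_steps, of one trajectory.

    Each weight is proportional to lamda ** (j - i), and the weights are
    normalized to sum to 1 within the trajectory.
    """
    raw = {
        (i, j): lamda ** (j - i)
        for i in range(n_steps)
        for j in range(i + 1, n_steps + 1)
    }
    total = sum(raw.values())
    return {span: w / total for span, w in raw.items()}


# A trajectory with two transitions has three sub-trajectories.
weights = geometric_within_weights(n_steps=2, lamda=0.5)
```

With lamda below 1, longer sub-trajectories receive less weight. The “geometric” variant would apply the same raw weights but normalize over all sub-trajectories in the batch rather than per trajectory.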
- lamda
discount factor for longer trajectories.
- log_reward_clip_min
minimum value for log rewards.
- calculate_log_state_flows(env, trajectories, log_pf_trajectories)
Calculate log state flows and masks for sink and terminal states.
- Parameters
trajectories (gfn.containers.Trajectories) – The trajectories data.
env (gfn.env.Env) – The environment object.
log_pf_trajectories (LogTrajectoriesTensor) –
- Returns
log_state_flows: Log state flows.
full_mask: A boolean tensor representing full states.
- calculate_masks(log_state_flows, trajectories)
Calculate masks for sink and terminal states.
- Parameters
log_state_flows (LogStateFlowsTensor) –
trajectories (gfn.containers.Trajectories) –
- Return type
Tuple[MaskTensor, MaskTensor, MaskTensor]
- calculate_preds(log_pf_trajectories_cum, log_state_flows, i)
Calculate the predictions tensor for the current sub-trajectory length.
- Parameters
log_pf_trajectories_cum (CumulativeLogProbsTensor) –
log_state_flows (LogStateFlowsTensor) –
i (int) –
- Return type
PredictionsTensor
- calculate_targets(trajectories, preds, log_pb_trajectories_cum, log_state_flows, is_terminal_mask, sink_states_mask, full_mask, i)
Calculate the targets tensor for the current sub-trajectory length.
- Parameters
trajectories (gfn.containers.Trajectories) –
preds (PredictionsTensor) –
log_pb_trajectories_cum (CumulativeLogProbsTensor) –
log_state_flows (LogStateFlowsTensor) –
is_terminal_mask (MaskTensor) –
sink_states_mask (MaskTensor) –
full_mask (MaskTensor) –
i (int) –
- Return type
TargetsTensor
- cumulative_logprobs(trajectories, log_p_trajectories)
Calculates the cumulative log probabilities for all trajectories.
- Parameters
trajectories (gfn.containers.Trajectories) – a batch of trajectories.
log_p_trajectories (LogTrajectoriesTensor) – log probabilities of each transition in each trajectory.
- Returns
cumulative sum of log probabilities of each trajectory.
- Return type
CumulativeLogProbsTensor
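Cumulative sums make the log probability of any sub-trajectory a single subtraction. The sketch below shows the idea for one trajectory. The leading 0.0 entry and the helper names are assumptions made for illustration.

```python
from itertools import accumulate
from typing import List, Sequence


def cumulative_logprobs(log_p: Sequence[float]) -> List[float]:
    """Prefix sums of per-transition log probabilities, starting with 0.0."""
    return list(accumulate(log_p, initial=0.0))


def sub_trajectory_logprob(cum: Sequence[float], i: int, j: int) -> float:
    """Log probability of the transitions from state index i to state index j."""
    return cum[j] - cum[i]


cum = cumulative_logprobs([-0.5, -1.0, -0.25])
last_two_steps = sub_trajectory_logprob(cum, 1, 3)
```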
- get_equal_contributions(trajectories)
Calculates contributions for the ‘equal’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
ContributionsTensor
- get_equal_within_contributions(trajectories)
Calculates contributions for the ‘equal_within’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
ContributionsTensor
- get_geometric_within_contributions(trajectories)
Calculates contributions for the ‘geometric_within’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
ContributionsTensor
- get_modified_db_contributions(trajectories)
Calculates contributions for the ‘ModifiedDB’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
ContributionsTensor
- get_scores(env, trajectories)
Scores all submitted trajectories.
- Returns
  - A list of tensors, each representing the scores of all sub-trajectories of length k, for k in [1, …, trajectories.max_length], where the score of a sub-trajectory \(\tau\) is \(\log P_F(\tau) + \log F(\tau_0) - \log P_B(\tau) - \log F(\tau_{-1})\). The shape of the k-th tensor is (trajectories.max_length - k + 1, trajectories.n_trajectories), k starting from 1.
  - A list of tensors representing what should be masked out in each element of the first list, given that not all sub-trajectories of length k exist for each trajectory. The entries of those tensors are True if the corresponding sub-trajectory does not exist.
- Parameters
env (gfn.env.Env) –
trajectories (gfn.containers.Trajectories) –
- Return type
Tuple[List[torchtyping.TensorType[0, float]], List[torchtyping.TensorType[0, float]]]
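The score \(\log P_F(\tau) + \log F(\tau_0) - \log P_B(\tau) - \log F(\tau_{-1})\) can be sketched for a single trajectory as follows. The function is a hypothetical helper working on plain lists. Masking and the treatment of terminal states are left out.

```python
from itertools import accumulate
from typing import List, Sequence


def sub_trajectory_scores(
    log_pf: Sequence[float],
    log_pb: Sequence[float],
    log_f: Sequence[float],
    k: int,
) -> List[float]:
    """Scores of every sub-trajectory with k transitions in one trajectory.

    log_pf[t] and log_pb[t] belong to the transition s_t -> s_{t+1};
    log_f[t] is the log state flow at s_t, so len(log_f) == len(log_pf) + 1.
    """
    cum_pf = list(accumulate(log_pf, initial=0.0))
    cum_pb = list(accumulate(log_pb, initial=0.0))
    n = len(log_pf)
    return [
        (cum_pf[i + k] - cum_pf[i]) + log_f[i]
        - (cum_pb[i + k] - cum_pb[i]) - log_f[i + k]
        for i in range(n - k + 1)
    ]


log_pf = [-1.0, -2.0]
log_pb = [-0.5, -0.5]
log_f = [3.0, 2.0, 1.0]
scores_k1 = sub_trajectory_scores(log_pf, log_pb, log_f, k=1)
scores_k2 = sub_trajectory_scores(log_pf, log_pb, log_f, k=2)
```

A trajectory with n transitions yields n - k + 1 scores for length k, which matches the first dimension of the k-th tensor described above.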
- get_tb_contributions(trajectories, all_scores)
Calculates contributions for the ‘TB’ weighting method.
- Parameters
trajectories (gfn.containers.Trajectories) –
all_scores (torchtyping.TensorType) –
- Return type
ContributionsTensor
- loss(env, trajectories)
- Parameters
env (gfn.env.Env) –
trajectories (gfn.containers.Trajectories) –
- Return type
torchtyping.TensorType[0, float]
- class gfn.gflownet.TBGFlowNet(pf, pb, on_policy=False, init_logZ=0.0, log_reward_clip_min=-12)
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
Holds the logZ estimate for the Trajectory Balance loss.
\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) represents the possible values for logZ, and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.logit_PB is a fixed DiscretePBEstimator.
- Parameters
pf (gfn.modules.GFNModule) –
pb (gfn.modules.GFNModule) –
on_policy (bool) –
init_logZ (float) –
log_reward_clip_min (float) –
- logZ
a LogZEstimator instance.
- log_reward_clip_min
minimum value to which the log reward is clamped.
- loss(env, trajectories)
Trajectory balance loss.
The trajectory balance loss is described in section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).
- Raises
ValueError – if the loss is NaN.
- Parameters
env (gfn.env.Env) –
trajectories (gfn.containers.Trajectories) –
- Return type
torchtyping.TensorType[0, float]
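For orientation, a minimal sketch of the trajectory balance objective from the cited paper: the squared residual \(\log Z + \sum \log P_F - \sum \log P_B - \log R\), averaged over the batch. The function name, the list inputs, and the mean reduction are assumptions. The library works with tensors and a learned logZ parameter.

```python
from typing import Sequence


def tb_loss(
    log_z: float,
    log_pf: Sequence[float],
    log_pb: Sequence[float],
    log_rewards: Sequence[float],
    log_reward_clip_min: float = -12.0,
) -> float:
    """Mean squared trajectory balance residual over a batch of trajectories.

    log_pf[n] and log_pb[n] are the summed log probabilities of trajectory n.
    """
    residuals = [
        log_z + f - b - max(r, log_reward_clip_min)
        for f, b, r in zip(log_pf, log_pb, log_rewards)
    ]
    return sum(x * x for x in residuals) / len(residuals)


# First trajectory is balanced, second is off by one nat.
loss = tb_loss(log_z=1.0, log_pf=[-1.0, -3.0], log_pb=[0.0, 0.0], log_rewards=[0.0, -1.0])
# A log reward of -20 is clamped to -12 before entering the residual.
clipped = tb_loss(log_z=0.0, log_pf=[-12.0], log_pb=[0.0], log_rewards=[-20.0])
```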
- class gfn.gflownet.TrajectoryBasedGFlowNet(pf, pb, on_policy=False)
Bases: PFBasedGFlowNet[gfn.containers.Trajectories]
Base class for gflownets that explicitly use \(P_F\).
- Parameters
pf (gfn.modules.GFNModule) –
pb (gfn.modules.GFNModule) –
on_policy (bool) –
- pf
GFNModule
- pb
GFNModule
- get_pfs_and_pbs(trajectories, fill_value=0.0)
Evaluates logprobs for each transition in each trajectory in the batch.
More specifically it evaluates \(\log P_F (s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in each trajectory in the batch.
Useful when the policy used to sample the trajectories is different from the one used to evaluate the loss. Otherwise we can use the logprobs directly from the trajectories.
- Parameters
trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.
fill_value (float) – Value to use for invalid states (i.e. \(s_f\) that is added to shorter trajectories).
- Return type
Tuple[torchtyping.TensorType[max_length, n_trajectories, torch.float], torchtyping.TensorType[max_length, n_trajectories, torch.float]]
- Returns
A tuple of float tensors of shape (max_length, n_trajectories) containing the log_pf and log_pb for each action in each trajectory. The first one can be None.
- Raises
ValueError – if the trajectories are backward.
AssertionError – when actions and states dimensions mismatch.
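The role of fill_value can be illustrated with plain lists: trajectories shorter than max_length are padded so that the result has shape (max_length, n_trajectories). The helper below is hypothetical and only mimics the layout. It does not evaluate any policy.

```python
from typing import List, Sequence


def pad_logprobs(per_trajectory: Sequence[Sequence[float]], fill_value: float = 0.0) -> List[List[float]]:
    """Arrange per-trajectory log probs as rows of steps, padding short trajectories."""
    max_length = max(len(t) for t in per_trajectory)
    return [
        [t[step] if step < len(t) else fill_value for t in per_trajectory]
        for step in range(max_length)
    ]


# One trajectory with three transitions and one with a single transition.
padded = pad_logprobs([[-0.5, -1.0, -0.25], [-2.0]])
# With fill_value=0.0, summing a column still gives that trajectory's log probability.
first_total = sum(row[0] for row in padded)
second_total = sum(row[1] for row in padded)
```

Padding with 0.0 is convenient in log space because the padded entries leave sums over the step dimension unchanged.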
- get_trajectories_scores(trajectories)
Given a batch of trajectories, calculate forward & backward policy scores.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
Tuple[torchtyping.TensorType[n_trajectories, torch.float], torchtyping.TensorType[n_trajectories, torch.float], torchtyping.TensorType[n_trajectories, torch.float]]
- to_training_samples(trajectories)
Converts trajectories to training samples. The type depends on the GFlowNet.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type