gfn.losses
Submodules
Package Contents
Classes
- class gfn.losses.DBParametrization
Bases:
gfn.losses.base.PFBasedParametrization
Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (no \(s_f\)) to \(\mathbb{R}^+\) (which we parametrize with logs, to avoid the non-negativity constraint), and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.logit_PB is a fixed LogitPBEstimator. Useful for the Detailed Balance Loss.
- logF :gfn.estimators.LogStateFlowEstimator
- class gfn.losses.DetailedBalance(parametrization, on_policy=False)
Bases:
gfn.losses.base.EdgeDecomposableLoss
Abstract Base Class for all GFN Losses
- Parameters
parametrization (DBParametrization) –
on_policy (bool) –
- __call__(transitions)
- Parameters
transitions (gfn.containers.Transitions) –
- Return type
LossTensor
- get_modified_scores(transitions)
DAG-GFN-style detailed balance, for when all states are connected to the sink
- Parameters
transitions (gfn.containers.Transitions) –
- Return type
ScoresTensor
- get_scores(transitions)
- Parameters
transitions (gfn.containers.Transitions) –
- class gfn.losses.EdgeDecomposableLoss(parametrization)
Bases:
Loss, abc.ABC
Abstract Base Class for all GFN Losses
- Parameters
parametrization (Parametrization) –
- abstract __call__(edges)
- Parameters
edges (gfn.containers.transitions.Transitions) –
- Return type
torchtyping.TensorType[0, float]
- class gfn.losses.FMParametrization
Bases:
gfn.losses.base.Parametrization
\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). This is equivalent to the set of functions from the internal nodes (i.e. without \(s_f\)) to \((\mathbb{R})^{n_{actions}}\), excluding the exit action. There is no need to enforce positivity if we parametrize log-flows.
- logF :gfn.estimators.LogEdgeFlowEstimator
- Pi(env, n_samples=1000, **actions_sampler_kwargs)
- Parameters
env (gfn.envs.Env) –
n_samples (int) –
- Return type
- class gfn.losses.FlowMatching(parametrization, alpha=1.0)
Bases:
gfn.losses.base.StateDecomposableLoss
- Parameters
parametrization (FMParametrization) –
- __call__(states_tuple)
- Parameters
states_tuple (Tuple[gfn.containers.states.States, gfn.containers.states.States]) –
- Return type
LossTensor
- flow_matching_loss(states)
Compute the flow-matching scores for the given states, defined as the log of the summed incoming flows minus the log of the summed outgoing flows. The states should not include s0. The batch shape should be (n_states,).
As of now, only discrete environments are handled.
- Parameters
states (gfn.containers.states.States) –
- Return type
ScoresTensor
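The per-state score described for flow_matching_loss can be illustrated with a minimal pure-Python sketch. It assumes the incoming and outgoing edge flows of a single state are already available as lists of log-flows; the helper names `logsumexp` and `flow_matching_score` are hypothetical and are not part of gfn.losses.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def flow_matching_score(incoming_log_flows, outgoing_log_flows):
    """Log of summed incoming flows minus log of summed outgoing flows.

    Hypothetical helper: the inputs are the log edge flows entering and
    leaving one state. The score is zero when the state is balanced.
    """
    return logsumexp(incoming_log_flows) - logsumexp(outgoing_log_flows)


# Incoming flows 1 and 3, outgoing flows 2 and 2: total in equals total out.
balanced = flow_matching_score(
    [math.log(1.0), math.log(3.0)], [math.log(2.0), math.log(2.0)]
)

# Incoming flows 1 and 3, outgoing flows 1 and 1: twice as much flow enters.
unbalanced = flow_matching_score(
    [math.log(1.0), math.log(3.0)], [math.log(1.0), math.log(1.0)]
)
```

Here `balanced` is zero up to rounding and `unbalanced` equals log 2, the log-ratio of incoming to outgoing flow.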
- reward_matching_loss(terminating_states)
- Parameters
terminating_states (gfn.containers.states.States) –
- Return type
LossTensor
- class gfn.losses.LogPartitionVarianceLoss(parametrization, log_reward_clip_min=-12, on_policy=False)
Bases:
gfn.losses.base.TrajectoryDecomposableLoss
- Parameters
parametrization (gfn.losses.base.PFBasedParametrization) –
log_reward_clip_min (float) –
on_policy (bool) –
- __call__(trajectories)
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
LossTensor
- class gfn.losses.Loss(parametrization)
Bases:
abc.ABC
Abstract Base Class for all GFN Losses
- Parameters
parametrization (Parametrization) –
- abstract __call__(*args, **kwargs)
- Return type
torchtyping.TensorType[0, float]
- class gfn.losses.PFBasedParametrization
Bases:
Parametrization, abc.ABC
Base class for parametrizations that explicitly use \(P_F\)
- logit_PB :gfn.estimators.LogitPBEstimator
- logit_PF :gfn.estimators.LogitPFEstimator
- Pi(env, n_samples=1000, **actions_sampler_kwargs)
- Parameters
env (gfn.envs.Env) –
n_samples (int) –
- Return type
- class gfn.losses.Parametrization
Bases:
abc.ABC
Abstract Base Class for Flow Parametrizations, as defined in Sec. 3 of GFlowNets Foundations. All attributes should be estimators, and each should have either a GFNModule attribute called module, or a torch.Tensor attribute called tensor with requires_grad=True.
- property parameters: dict
Return a dictionary of all parameters of the parametrization. Note that there might be duplicate parameters (e.g. when two NNs share parameters), in which case the optimizer should take as input set(self.parameters.values()).
- Return type
dict
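The de-duplication mentioned for the parameters property can be sketched without any dependency. This assumes parameters are hashed by object identity, so the same shared object collapses to one entry in a set; `FakeParam` is a hypothetical stand-in for a real parameter tensor, and the dictionary keys are made up for illustration.

```python
class FakeParam:
    """Hypothetical stand-in for a parameter tensor (hashed by identity)."""

    def __init__(self, name):
        self.name = name


# Two estimators share one trunk, so the same object appears under two keys.
shared_trunk = FakeParam("shared_trunk")
parameters = {
    "logit_PF.trunk": shared_trunk,
    "logit_PB.trunk": shared_trunk,
    "logZ": FakeParam("logZ"),
}

# set(...) keeps each distinct object once; this collection is what
# would be handed to the optimizer instead of the raw dict values.
unique_params = set(parameters.values())
```

The dictionary has three entries, but `unique_params` holds only two objects, so the shared trunk would not be registered twice.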
- P_T(env, n_samples, **kwargs)
- Parameters
env (gfn.envs.Env) –
n_samples (int) –
- Return type
gfn.distributions.TrajectoryBasedTerminatingStateDistribution
- abstract Pi(env, n_samples, **kwargs)
- Parameters
env (gfn.envs.Env) –
n_samples (int) –
- Return type
- load_state_dict(path)
- Parameters
path (str) –
- save_state_dict(path)
- Parameters
path (str) –
- class gfn.losses.StateDecomposableLoss(parametrization)
Bases:
Loss, abc.ABC
Abstract Base Class for all GFN Losses
- Parameters
parametrization (Parametrization) –
- abstract __call__(states_tuple)
Unlike the GFlowNets Foundations paper, we allow more flexibility by passing a tuple of states, the first one being the internal states of the trajectories (i.e. non-terminal states), and the second one being the terminal states of the trajectories. If these two are not handled differently, then they should be concatenated together.
- Parameters
states_tuple (Tuple[gfn.containers.states.States, gfn.containers.states.States]) –
- Return type
torchtyping.TensorType[0, float]
- class gfn.losses.SubTBParametrization
Bases:
gfn.losses.base.PFBasedParametrization
Exactly the same as DBParametrization
- logF :gfn.estimators.LogStateFlowEstimator
- class gfn.losses.SubTrajectoryBalance(parametrization, log_reward_clip_min=-12, weighing='geometric', lamda=0.9, on_policy=False)
Bases:
gfn.losses.base.TrajectoryDecomposableLoss
Abstract Base Class for all GFN Losses
- Parameters
parametrization (SubTBParametrization) –
log_reward_clip_min (float) –
weighing (Literal[DB, ModifiedDB, TB, geometric, equal, geometric_within, equal_within]) –
lamda (float) –
on_policy (bool) –
- __call__(trajectories)
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
LossTensor
- cumulative_logprobs(trajectories, log_p_trajectories)
- Parameters
trajectories (gfn.containers.Trajectories) – trajectories
log_p_trajectories (LogPTrajectoriesTensor) – log probabilities of each transition in each trajectory
- Returns
cumulative sum of log probabilities of each trajectory
- Return type
LogPTrajectoriesTensor
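A cumulative sum of per-transition log-probabilities is useful because the log-probability of any sub-trajectory becomes a difference of two entries. The sketch below works on a single trajectory stored as a plain list; the leading 0.0 entry and the function name `cumulative_logprobs_1d` are assumptions made for illustration, not the library's actual tensor layout.

```python
from itertools import accumulate


def cumulative_logprobs_1d(log_p_transitions):
    """Running sum of transition log-probabilities for one trajectory.

    Entry t holds the log-probability of the first t transitions, so the
    list has one more element than the input (entry 0 is 0.0).
    """
    return [0.0] + list(accumulate(log_p_transitions))


log_p = [-1.0, -2.0, -0.5]
cumulative = cumulative_logprobs_1d(log_p)

# Log-probability of the sub-trajectory made of transitions 1 and 2.
sub_log_p = cumulative[3] - cumulative[1]
```

With these inputs, `sub_log_p` equals `log_p[1] + log_p[2]`, obtained without re-summing the transitions.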
- get_scores(trajectories)
Returns two elements:
- A list of tensors, each of which represents the scores of all sub-trajectories of length k, for k in [1, …, trajectories.max_length], where the score of a sub-trajectory tau is log P_F(tau) + log F(tau_0) - log P_B(tau) - log F(tau_{-1}). The shape of the k-th tensor is (trajectories.max_length - k + 1, trajectories.n_trajectories), k starting from 1.
- A list of tensors representing what should be masked out in each element of the first list, given that not all sub-trajectories of length k exist for each trajectory. The entries of those tensors are True if the corresponding sub-trajectory does not exist.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
Tuple[List[ScoresTensor], List[ScoresTensor]]
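The sub-trajectory score log P_F(tau) + log F(tau_0) - log P_B(tau) - log F(tau_{-1}) can be worked through on one trajectory with plain lists. This is a sketch under simplifying assumptions: a single trajectory, no masking, and a log state flow given for every state including the last one. The function name `sub_trajectory_scores` is hypothetical.

```python
def sub_trajectory_scores(log_pf, log_pb, log_f, k):
    """Scores of all length-k sub-trajectories of a single trajectory.

    log_pf[t], log_pb[t]: forward and backward log-probabilities of
    transition t (from state t to state t + 1).
    log_f[t]: log state flow of state t, so len(log_f) == len(log_pf) + 1.
    """
    n_transitions = len(log_pf)
    scores = []
    for start in range(n_transitions - k + 1):
        end = start + k
        scores.append(
            sum(log_pf[start:end])
            + log_f[start]
            - sum(log_pb[start:end])
            - log_f[end]
        )
    return scores


# Toy values chosen so that every sub-trajectory is exactly balanced.
log_pf = [-1.0, -2.0, -0.5]
log_pb = [-0.5, -1.0, 0.0]
log_f = [2.0, 1.5, 0.5, 0.0]

all_scores = {k: sub_trajectory_scores(log_pf, log_pb, log_f, k) for k in (1, 2, 3)}
```

A trajectory with 3 transitions yields 3, 2 and 1 scores for k = 1, 2, 3, which matches the max_length - k + 1 leading dimension stated above.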
- class gfn.losses.TBParametrization
Bases:
gfn.losses.base.PFBasedParametrization
\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) represents the possible values for logZ, and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.logit_PB is a fixed LogitPBEstimator. Useful for the Trajectory Balance Loss.
- logZ :gfn.estimators.LogZEstimator
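To show how logZ, the forward probabilities and the backward probabilities combine, here is a minimal sketch of the standard trajectory balance objective for one complete trajectory: the squared difference between logZ + log P_F(tau) and log R(x) + log P_B(tau). The function name and the plain-list inputs are assumptions for illustration; reward clipping and batching are omitted.

```python
def trajectory_balance_loss(log_z, log_pf, log_pb, log_reward):
    """Squared trajectory balance residual for one complete trajectory.

    log_pf, log_pb: per-transition forward and backward log-probabilities.
    log_reward: log R(x) of the terminating state x.
    """
    residual = log_z + sum(log_pf) - sum(log_pb) - log_reward
    return residual ** 2


# Balanced case: logZ + log P_F equals log R + log P_B.
balanced_loss = trajectory_balance_loss(1.0, [-1.0, -0.5], [-0.5, 0.0], 0.0)

# Same trajectory with a larger reward: the residual is -1, the loss is 1.
unbalanced_loss = trajectory_balance_loss(1.0, [-1.0, -0.5], [-0.5, 0.0], 1.0)
```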
- class gfn.losses.TrajectoryBalance(parametrization, log_reward_clip_min=-12, on_policy=False)
Bases:
gfn.losses.base.TrajectoryDecomposableLoss
- Parameters
parametrization (TBParametrization) –
log_reward_clip_min (float) –
on_policy (bool) –
- __call__(trajectories)
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
LossTensor
- class gfn.losses.TrajectoryDecomposableLoss(parametrization)
Bases:
Loss, abc.ABC
Abstract Base Class for all GFN Losses
- Parameters
parametrization (Parametrization) –
- abstract __call__(trajectories)
- Parameters
trajectories (gfn.containers.trajectories.Trajectories) –
- Return type
torchtyping.TensorType[0, float]
- get_pfs_and_pbs(trajectories, fill_value=0.0, temperature=1.0, epsilon=0.0, no_pf=False)
Evaluate log_pf and log_pb for each action in each trajectory in the batch. This is useful when the policy used to sample the trajectories is different from the one used to evaluate the loss.
- Parameters
trajectories (Trajectories) – Trajectories to evaluate.
fill_value (float, optional) – Value to use for invalid states (i.e. s_f that is added to shorter trajectories). Defaults to 0.0.
The next parameters correspond to how the actions_sampler evaluates each action:
temperature (float, optional) – Temperature to use for the softmax. Defaults to 1.0.
epsilon (float, optional) – Epsilon to use for the softmax. Defaults to 0.0.
no_pf (bool, optional) – If True, log_pf is not evaluated and None is returned in its place. Defaults to False.
- Raises
ValueError – if the trajectories are backward.
- Returns
A tuple of float tensors of shape (max_length, n_trajectories) containing the log_pf and log_pb for each action in each trajectory. The first one can be None.
- Return type
Tuple[LogPTrajectoriesTensor | None, LogPTrajectoriesTensor]
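The role of fill_value and of the (max_length, n_trajectories) shape can be illustrated with nested lists. The sketch assumes each trajectory's transition log-probabilities are available as a list of its own length; the helper `pad_log_probs` is hypothetical and only mimics the layout described above.

```python
def pad_log_probs(per_trajectory_log_probs, fill_value=0.0):
    """Arrange ragged log-probabilities as (max_length, n_trajectories).

    Positions past the end of a shorter trajectory receive fill_value.
    """
    max_length = max(len(traj) for traj in per_trajectory_log_probs)
    return [
        [traj[t] if t < len(traj) else fill_value for traj in per_trajectory_log_probs]
        for t in range(max_length)
    ]


# One trajectory with three transitions and one with a single transition.
padded = pad_log_probs([[-1.0, -2.0, -0.5], [-0.3]])

# With fill_value=0.0, summing over the time axis still gives each
# trajectory's total log-probability.
totals = [sum(row[i] for row in padded) for i in range(2)]
```

A fill value of 0.0 is convenient here because padded positions contribute nothing when log-probabilities are summed along the time axis.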
- get_trajectories_scores(trajectories)
- Parameters
trajectories (gfn.containers.trajectories.Trajectories) –
- Return type
Tuple[ScoresTensor, ScoresTensor, ScoresTensor]