gfn.losses

Submodules

Package Contents

Classes

DBParametrization

Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\) (log state flows, \(P_F\) and \(P_B\)). Useful for the Detailed Balance loss.

DetailedBalance

Detailed Balance loss, computed over transitions.

EdgeDecomposableLoss

Abstract base class for GFN losses that decompose over edges (transitions).

FMParametrization

\(\mathcal{O}_{edge}\), the set of functions from the non-terminating edges to \(\mathbb{R}^+\), parametrized with log edge flows. Useful for the Flow Matching loss.

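FlowMatching

Flow Matching loss, computed over states.

LogPartitionVarianceLoss

Loss based on the variance of per-trajectory estimates of the log-partition function.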

Loss

Abstract Base Class for all GFN Losses

PFBasedParametrization

Base class for parametrizations that explicitly use \(P_F\).

Parametrization

Abstract Base Class for Flow Parametrizations, as defined in Sec. 3 of GFlowNets Foundations.

StateDecomposableLoss

Abstract base class for GFN losses that decompose over states.

SubTBParametrization

Exactly the same as DBParametrization. Useful for the Sub-Trajectory Balance loss.

SubTrajectoryBalance

Sub-Trajectory Balance loss, computed over trajectories.

TBParametrization

\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\) (logZ, \(P_F\) and \(P_B\)). Useful for the Trajectory Balance loss.

TrajectoryBalance

Trajectory Balance loss, computed over trajectories.

TrajectoryDecomposableLoss

Abstract base class for GFN losses that decompose over trajectories.

class gfn.losses.DBParametrization

Bases: gfn.losses.base.PFBasedParametrization

Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (no \(s_f\)) to \(\mathbb{R}^+\) (which we parametrize with logs, to avoid the non-negativity constraint), and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.logit_PB is a fixed LogitPBEstimator. Useful for the Detailed Balance Loss.

logF: gfn.estimators.LogStateFlowEstimator

class gfn.losses.DetailedBalance(parametrization, on_policy=False)

Bases: gfn.losses.base.EdgeDecomposableLoss

Detailed Balance loss.

Parameters
  • parametrization (DBParametrization) –

  • on_policy (bool) –

__call__(transitions)
Parameters

transitions (gfn.containers.Transitions) –

Return type

LossTensor

get_modified_scores(transitions)

DAG-GFN-style detailed balance, for environments in which every state is connected to the sink state \(s_f\).

Parameters

transitions (gfn.containers.Transitions) –

Return type

ScoresTensor
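A minimal sketch of the DAG-GFN-style residual, assuming (as the description says) that every state is connected to the sink \(s_f\), so that each state's flow can be written F(s) = R(s) / P_F(s_f | s) and no separate state-flow estimator is needed. The function names and the scalar, one-transition-at-a-time interface are hypothetical; the library method works on batched Transitions.

```python
import math
from typing import List


def modified_db_score(
    log_r_s: float,
    log_r_next: float,
    log_pf_step: float,
    log_pb_step: float,
    log_pf_exit_s: float,
    log_pf_exit_next: float,
) -> float:
    """DAG-GFN-style detailed-balance residual for a transition s -> s'.

    Assumes every state can exit to the sink, so F(s) = R(s) / P_F(s_f | s).
    Substituting into F(s) P_F(s'|s) = F(s') P_B(s|s') gives, in log space:
        log R(s) + log P_F(s'|s) + log P_F(s_f|s')
          - log R(s') - log P_B(s|s') - log P_F(s_f|s) = 0
    """
    return (
        log_r_s + log_pf_step + log_pf_exit_next
        - log_r_next - log_pb_step - log_pf_exit_s
    )


def mean_squared(scores: List[float]) -> float:
    """Average squared residual over a batch of transitions."""
    return sum(s * s for s in scores) / len(scores)


# Consistent toy values: F(s) = 2 / 0.5 = 4 and F(s') = 3 / 0.5 = 6, so
# F(s) P_F(s'|s) = 4 * 0.5 = 2 equals F(s') P_B(s|s') = 6 * (1/3) = 2.
score = modified_db_score(
    log_r_s=math.log(2.0),
    log_r_next=math.log(3.0),
    log_pf_step=math.log(0.5),
    log_pb_step=math.log(1 / 3),
    log_pf_exit_s=math.log(0.5),
    log_pf_exit_next=math.log(0.5),
)
```

Because the flows in the toy example balance exactly, the residual is zero up to floating-point error.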

get_scores(transitions)
Parameters

transitions (gfn.containers.Transitions) –
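For comparison, the standard detailed-balance residual that DBParametrization supports (a log state flow from logF, plus \(P_F\) and \(P_B\)) can be sketched as below. This is an illustration under stated assumptions, not the library's get_scores: it assumes the exit transition s -> s_f is scored against the reward through F(s) P_F(s_f | s) = R(s), and that the loss is the mean squared residual. All names are hypothetical.

```python
import math
from typing import List


def db_score(log_f_s: float, log_pf: float, log_f_next: float, log_pb: float) -> float:
    """Log-space residual of F(s) P_F(s'|s) = F(s') P_B(s|s') for s -> s'."""
    return log_f_s + log_pf - log_f_next - log_pb


def exit_score(log_f_s: float, log_pf_exit: float, log_reward_s: float) -> float:
    """Residual of the assumed boundary condition F(s) P_F(s_f|s) = R(s)
    for the transition from a terminating state s to the sink s_f."""
    return log_f_s + log_pf_exit - log_reward_s


def db_loss(scores: List[float]) -> float:
    """Mean squared residual over a batch of transitions."""
    return sum(s * s for s in scores) / len(scores)


# Toy flows that satisfy detailed balance exactly: F(s) = 4 splits into
# 4 * 0.5 = 2 towards s' and 4 * 0.5 = 2 towards s_f (with R(s) = 2),
# and F(s') P_B(s|s') = 6 * (1/3) = 2.
scores = [
    db_score(math.log(4.0), math.log(0.5), math.log(6.0), math.log(1 / 3)),
    exit_score(math.log(4.0), math.log(0.5), math.log(2.0)),
]
```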

class gfn.losses.EdgeDecomposableLoss(parametrization)

Bases: Loss, abc.ABC

Abstract base class for GFN losses that decompose over edges (transitions).

Parameters

parametrization (Parametrization) –

abstract __call__(edges)
Parameters

edges (gfn.containers.transitions.Transitions) –

Return type

torchtyping.TensorType[0, float]

class gfn.losses.FMParametrization

Bases: gfn.losses.base.Parametrization

\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). This is equivalent to the set of functions from the internal nodes (i.e. without \(s_f\)) to \(\mathbb{R}^{n_{\text{actions}} - 1}\), i.e. one value per action excluding the exit action. Positivity need not be enforced, because log-flows are parametrized.

logF: gfn.estimators.LogEdgeFlowEstimator

Pi(env, n_samples=1000, **actions_sampler_kwargs)
Parameters
  • env –

  • n_samples (int) –

  • **actions_sampler_kwargs –

Return type

gfn.distributions.TrajectoryDistribution

class gfn.losses.FlowMatching(parametrization, alpha=1.0)

Bases: gfn.losses.base.StateDecomposableLoss

Parameters
  • parametrization (FMParametrization) –

  • alpha (float) –

__call__(states_tuple)
Parameters

states_tuple (Tuple[gfn.containers.states.States, gfn.containers.states.States]) –

Return type

LossTensor

flow_matching_loss(states)

Compute the flow-matching (FM) scores for the given states, defined as the log-sum of incoming flows minus the log-sum of outgoing flows. The states should not include \(s_0\), and the batch shape should be (n_states,).

Only discrete environments are currently supported.

Parameters

states (gfn.containers.states.States) –

Return type

ScoresTensor
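The FM residual described above can be sketched for a single state with the standard library. The sketch assumes the incoming and outgoing edge log-flows have already been gathered into plain lists (in gfn.losses they come from the parametrization's logF); fm_score and logsumexp are hypothetical helpers. It also shows why \(s_0\) must be excluded: \(s_0\) has no incoming edges, so its incoming log-sum would be the log of an empty sum.

```python
import math
from typing import Sequence


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs)); xs must be non-empty."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def fm_score(incoming: Sequence[float], outgoing: Sequence[float]) -> float:
    """Flow-matching residual of one state: log-sum of incoming edge log-flows
    minus log-sum of outgoing edge log-flows (0 when the flows balance)."""
    return logsumexp(incoming) - logsumexp(outgoing)


# Parents send flows 1 and 3 into the state; it sends 2 and 2 to its children.
balanced = fm_score([math.log(1.0), math.log(3.0)], [math.log(2.0), math.log(2.0)])
```

Working with log-flows and logsumexp avoids overflow when flows are large, which matches the document's choice of parametrizing log-flows.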

reward_matching_loss(terminating_states)
Parameters

terminating_states (gfn.containers.states.States) –

Return type

LossTensor
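FlowMatching.__call__ receives a tuple of intermediary and terminating states. A plausible way the two parts combine, assuming alpha (default 1.0) weights a reward-matching term on the terminating states and that both terms are mean squared residuals, is sketched below. The exact terms the library uses are not documented here, so treat flow_matching_total as hypothetical.

```python
from typing import List


def mean_squared(xs: List[float]) -> float:
    """Mean of squared values."""
    return sum(x * x for x in xs) / len(xs)


def flow_matching_total(
    fm_scores: List[float],
    terminating_log_flows: List[float],
    terminating_log_rewards: List[float],
    alpha: float = 1.0,
) -> float:
    """Hypothetical combination of the two halves of states_tuple:
    flow matching on intermediary states plus an alpha-weighted
    reward-matching term on terminating states."""
    fm = mean_squared(fm_scores)
    rm = mean_squared(
        [f - r for f, r in zip(terminating_log_flows, terminating_log_rewards)]
    )
    return fm + alpha * rm
```

Keeping the two terms separate is what lets the states_tuple interface treat internal and terminal states differently.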

class gfn.losses.LogPartitionVarianceLoss(parametrization, log_reward_clip_min=-12, on_policy=False)

Bases: gfn.losses.base.TrajectoryDecomposableLoss

Parameters
  • parametrization –

  • log_reward_clip_min (float) –

  • on_policy (bool) –

__call__(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) –

Return type

LossTensor
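This section gives no description of the loss. A minimal sketch, assuming (as the name suggests) that it is the variance across a batch of per-trajectory log Z estimates, which removes the need for a learned logZ, and that log_reward_clip_min clamps log-rewards from below as in TrajectoryBalance. Both function names are hypothetical.

```python
from typing import List


def log_z_estimates(
    sum_log_pf: List[float],
    sum_log_pb: List[float],
    log_rewards: List[float],
    log_reward_clip_min: float = -12.0,
) -> List[float]:
    """Per-trajectory log Z implied by trajectory balance:
    log Z ~= log R(x) + sum log P_B - sum log P_F, with log R clipped from below."""
    return [
        max(lr, log_reward_clip_min) + pb - pf
        for pf, pb, lr in zip(sum_log_pf, sum_log_pb, log_rewards)
    ]


def log_partition_variance_loss(estimates: List[float]) -> float:
    """Population variance of the per-trajectory log Z estimates."""
    mean = sum(estimates) / len(estimates)
    return sum((e - mean) ** 2 for e in estimates) / len(estimates)
```

If the sampler were perfect, every trajectory would imply the same log Z and the variance would be zero.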

class gfn.losses.Loss(parametrization)

Bases: abc.ABC

Abstract Base Class for all GFN Losses

Parameters

parametrization (Parametrization) –

abstract __call__(*args, **kwargs)
Return type

torchtyping.TensorType[0, float]

class gfn.losses.PFBasedParametrization

Bases: Parametrization, abc.ABC

Base class for parametrizations that explicitly use \(P_F\).

logit_PB: gfn.estimators.LogitPBEstimator
logit_PF: gfn.estimators.LogitPFEstimator

Pi(env, n_samples=1000, **actions_sampler_kwargs)
Parameters
  • env –

  • n_samples (int) –

  • **actions_sampler_kwargs –

Return type

gfn.distributions.TrajectoryDistribution

class gfn.losses.Parametrization

Bases: abc.ABC

Abstract Base Class for Flow Parametrizations, as defined in Sec. 3 of GFlowNets Foundations. All attributes should be estimators, and each should have either a GFNModule attribute called module or a torch.Tensor attribute called tensor with requires_grad=True.

property parameters: dict

Return a dictionary of all parameters of the parametrization. Note that there might be duplicate parameters (e.g. when two NNs share parameters), in which case the optimizer should take as input set(self.parameters.values()).

Return type

dict
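The advice about duplicate parameters can be illustrated without a deep-learning framework. In this sketch Weight is a hypothetical stand-in for a trainable tensor (torch tensors also hash by identity), and unique_parameters is a hypothetical helper that removes shared entries while keeping their first-seen order, which set() does not guarantee.

```python
from typing import Dict, List


class Weight:
    """Stand-in for a trainable tensor; like torch tensors, it hashes by identity."""

    def __init__(self, name: str) -> None:
        self.name = name


def unique_parameters(parameters: Dict[str, Weight]) -> List[Weight]:
    """Drop duplicate entries (e.g. a trunk shared by two estimators)
    before handing them to an optimizer, preserving first-seen order."""
    seen = set()
    unique = []
    for p in parameters.values():
        if id(p) not in seen:
            seen.add(id(p))
            unique.append(p)
    return unique
```

Passing a duplicated tensor to an optimizer twice would register it twice, which is why deduplication (with set() or an order-preserving variant) is recommended.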

P_T(env, n_samples, **kwargs)
Parameters
  • env –

  • n_samples –

  • **kwargs –

Return type

gfn.distributions.TrajectoryBasedTerminatingStateDistribution

abstract Pi(env, n_samples, **kwargs)
Parameters
  • env –

  • n_samples –

  • **kwargs –

Return type

gfn.distributions.TrajectoryDistribution

load_state_dict(path)
Parameters

path (str) –

save_state_dict(path)
Parameters

path (str) –

class gfn.losses.StateDecomposableLoss(parametrization)

Bases: Loss, abc.ABC

Abstract base class for GFN losses that decompose over states.

Parameters

parametrization (Parametrization) –

abstract __call__(states_tuple)

Unlike the GFlowNets Foundations paper, we allow more flexibility by passing a tuple of states, the first one being the internal states of the trajectories (i.e. non-terminal states), and the second one being the terminal states of the trajectories. If these two are not handled differently, then they should be concatenated together.

Parameters

states_tuple (Tuple[gfn.containers.states.States, gfn.containers.states.States]) –

Return type

torchtyping.TensorType[0, float]

class gfn.losses.SubTBParametrization

Bases: gfn.losses.base.PFBasedParametrization

Exactly the same as DBParametrization.

logF: gfn.estimators.LogStateFlowEstimator

class gfn.losses.SubTrajectoryBalance(parametrization, log_reward_clip_min=-12, weighing='geometric', lamda=0.9, on_policy=False)

Bases: gfn.losses.base.TrajectoryDecomposableLoss

Sub-Trajectory Balance loss.

Parameters
  • parametrization (SubTBParametrization) –

  • log_reward_clip_min (float) –

  • weighing (Literal['DB', 'ModifiedDB', 'TB', 'geometric', 'equal', 'geometric_within', 'equal_within']) –

  • lamda (float) –

  • on_policy (bool) –

__call__(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) –

Return type

LossTensor

cumulative_logprobs(trajectories, log_p_trajectories)
Parameters
  • trajectories (gfn.containers.Trajectories) – trajectories

  • log_p_trajectories (LogPTrajectoriesTensor) – log probabilities of each transition in each trajectory

Returns

cumulative sum of log probabilities of each trajectory

Return type

LogPTrajectoriesTensor
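A sketch of what a cumulative sum of log-probabilities makes possible, for one trajectory and under the assumption that a leading 0.0 is prepended: the log-probability of the steps from state i to state j is then a single difference cum[j] - cum[i]. The library method works on (max_length, n_trajectories) tensors; this list version is hypothetical.

```python
from itertools import accumulate
from typing import List


def cumulative_logprobs(log_p_steps: List[float]) -> List[float]:
    """Prefix sums of per-step log-probabilities for one trajectory, with a
    leading 0.0 so that the sum over steps i..j-1 equals cum[j] - cum[i]."""
    return [0.0] + list(accumulate(log_p_steps))
```

Computing prefix sums once turns every sub-trajectory sum into an O(1) subtraction, which is what makes scoring all sub-trajectories affordable.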

get_scores(trajectories)

Returns two elements:

  • A list of tensors, each of which represents the scores of all sub-trajectories of length k, for k in [1, …, trajectories.max_length], where the score of a sub-trajectory \(\tau\) is \(\log P_F(\tau) + \log F(\tau_0) - \log P_B(\tau) - \log F(\tau_{-1})\). The shape of the k-th tensor is (trajectories.max_length - k + 1, trajectories.n_trajectories), with k starting from 1.

  • A list of tensors indicating what should be masked out in each element of the first list, since not every trajectory contains a sub-trajectory of every length k. An entry is True if the corresponding sub-trajectory does not exist.

Parameters

trajectories (gfn.containers.Trajectories) –

Return type

Tuple[List[ScoresTensor], List[ScoresTensor]]
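The score definition above can be sketched for a single complete trajectory using prefix sums. Assumptions: log_f has one entry per visited state, with the last entry standing in for log R(x), and "geometric" weighing is taken to mean the lamda**k-weighted average of squared scores from the Sub-Trajectory Balance formulation. The library may normalize differently across a batch, and it needs the masks described above because trajectories in a batch have different lengths. Function names are hypothetical.

```python
from itertools import accumulate
from typing import List


def subtb_scores(log_pf: List[float], log_pb: List[float], log_f: List[float]) -> List[List[float]]:
    """Scores of all sub-trajectories of one complete trajectory.

    log_pf[t], log_pb[t]: log-probabilities of step t (t = 0..n-1).
    log_f[t]: log-flow of state t (t = 0..n); log_f[n] stands in for log R(x).
    scores[k-1][i] is the score of the length-k sub-trajectory starting at state i:
        log P_F(tau) + log F(tau_0) - log P_B(tau) - log F(tau_{-1})
    """
    n = len(log_pf)
    cum_pf = [0.0] + list(accumulate(log_pf))
    cum_pb = [0.0] + list(accumulate(log_pb))
    scores = []
    for k in range(1, n + 1):
        row = []
        for i in range(n - k + 1):  # n - k + 1 sub-trajectories of length k
            j = i + k
            row.append((cum_pf[j] - cum_pf[i]) + log_f[i] - (cum_pb[j] - cum_pb[i]) - log_f[j])
        scores.append(row)
    return scores


def geometric_subtb_loss(scores: List[List[float]], lamda: float = 0.9) -> float:
    """Weighted average of squared scores, weighting length-k sub-trajectories by lamda**k."""
    num = 0.0
    den = 0.0
    for k, row in enumerate(scores, start=1):
        w = lamda ** k
        for s in row:
            num += w * s * s
            den += w
    return num / den
```

With lamda close to 0 the loss is dominated by single transitions (detailed-balance-like), and with lamda large it leans towards whole trajectories (trajectory-balance-like).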

class gfn.losses.TBParametrization

Bases: gfn.losses.base.PFBasedParametrization

\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) represents the possible values for logZ, and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.logit_PB is a fixed LogitPBEstimator. Useful for the Trajectory Balance Loss.

logZ: gfn.estimators.LogZEstimator

class gfn.losses.TrajectoryBalance(parametrization, log_reward_clip_min=-12, on_policy=False)

Bases: gfn.losses.base.TrajectoryDecomposableLoss

Parameters
  • parametrization (TBParametrization) –

  • log_reward_clip_min (float) –

  • on_policy (bool) –

__call__(trajectories)
Parameters

trajectories (gfn.containers.Trajectories) –

Return type

LossTensor
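A sketch of the trajectory-balance residual, based on the condition Z ∏ P_F(s_{t+1} | s_t) = R(x) ∏ P_B(s_t | s_{t+1}) that the logZ, \(P_F\) and \(P_B\) of TBParametrization are meant to satisfy. It assumes log_reward_clip_min clamps log-rewards from below and that the loss is the mean squared residual; tb_scores and tb_loss are hypothetical names, and the inputs are per-trajectory sums of step log-probabilities.

```python
from typing import List


def tb_scores(
    log_z: float,
    sum_log_pf: List[float],
    sum_log_pb: List[float],
    log_rewards: List[float],
    log_reward_clip_min: float = -12.0,
) -> List[float]:
    """Trajectory-balance residual per trajectory:
    log Z + sum log P_F - sum log P_B - max(log R(x), log_reward_clip_min)."""
    return [
        log_z + pf - pb - max(lr, log_reward_clip_min)
        for pf, pb, lr in zip(sum_log_pf, sum_log_pb, log_rewards)
    ]


def tb_loss(scores: List[float]) -> float:
    """Mean squared trajectory-balance residual."""
    return sum(s * s for s in scores) / len(scores)
```

Clipping keeps trajectories with near-zero reward from producing unboundedly large residuals.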

class gfn.losses.TrajectoryDecomposableLoss(parametrization)

Bases: Loss, abc.ABC

Abstract base class for GFN losses that decompose over trajectories.

Parameters

parametrization (Parametrization) –

abstract __call__(trajectories)
Parameters

trajectories (gfn.containers.trajectories.Trajectories) –

Return type

torchtyping.TensorType[0, float]

get_pfs_and_pbs(trajectories, fill_value=0.0, temperature=1.0, epsilon=0.0, no_pf=False)

Evaluate log_pf and log_pb for each action in each trajectory in the batch. This is useful when the policy used to sample the trajectories is different from the one used to evaluate the loss.

Parameters
  • trajectories (Trajectories) – Trajectories to evaluate.

  • fill_value (float, optional) – Value to use for invalid states (i.e. the s_f padding added to shorter trajectories). Defaults to 0.0.

  • The following parameters correspond to how the actions_sampler evaluates each action.

  • temperature (float, optional) – Temperature to use for the softmax. Defaults to 1.0.

  • epsilon (float, optional) – Epsilon to use for the softmax. Defaults to 0.0.

  • no_pf (bool, optional) – If True, log_pf is not evaluated and None is returned in its place. Defaults to False.

Raises

ValueError – if the trajectories are backward.

Returns

A tuple of float tensors of shape (max_length, n_trajectories) containing the log_pf and log_pb for each action in each trajectory. The first one is None when no_pf is True.

Return type

Tuple[LogPTrajectoriesTensor | None, LogPTrajectoriesTensor]
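The temperature and epsilon options, and the fill_value padding, can be sketched as follows. The docstring only says these parameters are used "for the softmax", so this sketch assumes temperature divides the logits and epsilon mixes in a uniform distribution over the valid actions; action_log_prob and pad_log_probs are hypothetical helpers, not gfn.losses functions.

```python
import math
from typing import List


def action_log_prob(
    logits: List[float],
    mask: List[bool],
    action: int,
    temperature: float = 1.0,
    epsilon: float = 0.0,
) -> float:
    """Log-probability of `action` under a tempered softmax restricted to
    valid actions, mixed with a uniform distribution (weight epsilon)."""
    if not mask[action]:
        return -math.inf  # invalid actions get zero probability
    valid = [i for i, ok in enumerate(mask) if ok]
    scaled = {i: logits[i] / temperature for i in valid}
    m = max(scaled.values())  # subtract the max for numerical stability
    z = sum(math.exp(v - m) for v in scaled.values())
    softmax_p = math.exp(scaled[action] - m) / z
    p = (1.0 - epsilon) * softmax_p + epsilon / len(valid)
    return math.log(p)


def pad_log_probs(per_trajectory: List[List[float]], fill_value: float = 0.0) -> List[List[float]]:
    """Stack ragged per-trajectory log-probs into a (max_length, n_trajectories)
    table, filling the steps past each trajectory's end with fill_value."""
    max_length = max(len(t) for t in per_trajectory)
    return [
        [t[step] if step < len(t) else fill_value for t in per_trajectory]
        for step in range(max_length)
    ]
```

With the default fill_value of 0.0, padded steps contribute nothing when log-probabilities are summed along a trajectory.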

get_trajectories_scores(trajectories)
Parameters

trajectories (gfn.containers.trajectories.Trajectories) –

Return type

Tuple[ScoresTensor, ScoresTensor, ScoresTensor]