gfn.losses.base

Module Contents

Classes

EdgeDecomposableLoss

Abstract Base Class for GFN losses that decompose over edges (transitions)

Loss

Abstract Base Class for all GFN Losses

PFBasedParametrization

Base class for parametrizations that explicitly use $P_F$

Parametrization

Abstract Base Class for Flow Parametrizations, as defined in Sec. 3 of GFlowNet Foundations

StateDecomposableLoss

Abstract Base Class for GFN losses that decompose over states

TrajectoryDecomposableLoss

Abstract Base Class for GFN losses that decompose over trajectories

Attributes

LogPTrajectoriesTensor

ScoresTensor

class gfn.losses.base.EdgeDecomposableLoss(parametrization)

Bases: Loss, abc.ABC

Abstract Base Class for GFN losses that decompose over edges (transitions)

Parameters

parametrization (Parametrization) –

abstract __call__(edges)
Parameters

edges (gfn.containers.transitions.Transitions) –

Return type

torchtyping.TensorType[0, float]
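For intuition, an edge-decomposable loss computes one term per transition s -> s' and averages the terms over the batch. The minimal sketch below uses the detailed-balance residual from the GFlowNet literature, log F(s) + log P_F(s'|s) - log F(s') - log P_B(s|s'), as the per-edge term. The Edge record, the function name, and the plain-float inputs are hypothetical stand-ins for Transitions and estimator outputs; this is not the library's implementation, and it returns a Python float rather than a 0-dimensional tensor.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Edge:
    # Hypothetical stand-in for one transition s -> s' plus estimator outputs on it.
    log_flow_s: float      # log F(s)
    log_flow_next: float   # log F(s')
    log_pf: float          # log P_F(s' | s)
    log_pb: float          # log P_B(s | s')


def detailed_balance_loss(edges: List[Edge]) -> float:
    """Mean squared detailed-balance residual over a batch of edges."""
    if not edges:
        raise ValueError("need at least one edge")
    # Detailed balance requires F(s) P_F(s'|s) = F(s') P_B(s|s'); in log space the
    # residual below is zero exactly when that equality holds for the edge.
    residuals = [
        (e.log_flow_s + e.log_pf - e.log_flow_next - e.log_pb) ** 2 for e in edges
    ]
    return sum(residuals) / len(residuals)
```

Because each term depends only on one edge, any batch of transitions can be used, regardless of which trajectories they came from.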

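gfn.losses.base.LogPTrajectoriesTensor

Type alias for float tensors of shape (max_length, n_trajectories) holding one log-probability per action in each trajectory (see get_pfs_and_pbs).

class gfn.losses.base.Loss(parametrization)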

Bases: abc.ABC

Abstract Base Class for all GFN Losses

Parameters

parametrization (Parametrization) –

abstract __call__(*args, **kwargs)
Return type

torchtyping.TensorType[0, float]
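The pattern documented above (the loss stores a parametrization, and concrete subclasses implement __call__ to return a scalar) can be mirrored with the standard abc module. The classes below are simplified, hypothetical stand-ins rather than the real gfn classes; in particular, they return a float instead of a 0-dimensional tensor.

```python
from abc import ABC, abstractmethod
from typing import Any


class Loss(ABC):
    """Simplified stand-in mirroring the documented interface (not the real class)."""

    def __init__(self, parametrization: Any):
        # The loss keeps a reference to the parametrization whose estimators it trains.
        self.parametrization = parametrization

    @abstractmethod
    def __call__(self, *args: Any, **kwargs: Any) -> float:
        """Return a scalar loss value."""


class ConstantLoss(Loss):
    # Hypothetical toy subclass: ignores its inputs and returns a fixed value.
    def __call__(self, *args: Any, **kwargs: Any) -> float:
        return 0.0
```

Since __call__ is abstract, instantiating the base class directly raises TypeError. Only subclasses that implement it, such as the edge-, state-, and trajectory-decomposable variants, can be constructed.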

class gfn.losses.base.PFBasedParametrization

Bases: Parametrization, abc.ABC

Base class for parametrizations that explicitly use \(P_F\)

logit_PB: gfn.estimators.LogitPBEstimator
logit_PF: gfn.estimators.LogitPFEstimator
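Pi(env, n_samples=1000, **actions_sampler_kwargs)
Parameters
  • env – Environment in which the trajectories are sampled.

  • n_samples (int, optional) – Number of trajectories to sample. Defaults to 1000.

  • **actions_sampler_kwargs – Keyword arguments passed to the actions sampler.

Return type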

gfn.distributions.TrajectoryDistribution
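Conceptually, a trajectory distribution defined by $P_F$ is sampled by starting at the initial state and repeatedly drawing the next state from $P_F$ until a terminal state is reached. The sketch below illustrates this with a hypothetical dictionary-based forward policy over a toy DAG. It does not use the library's estimators, actions sampler, or TrajectoryDistribution class; the names and data layout are assumptions made for illustration.

```python
import random
from typing import Dict, List, Tuple

# Hypothetical toy DAG: each non-terminal state maps to (child, P_F(child | state)) pairs.
# States with no entry are terminal.
ForwardPolicy = Dict[str, List[Tuple[str, float]]]


def sample_trajectories(
    p_f: ForwardPolicy, s0: str, n_samples: int, seed: int = 0
) -> List[List[str]]:
    """Sample n_samples trajectories from s0 by repeatedly drawing s' ~ P_F(. | s)."""
    rng = random.Random(seed)
    trajectories = []
    for _ in range(n_samples):
        state, traj = s0, [s0]
        while state in p_f:  # stop once a terminal state is reached
            children, probs = zip(*p_f[state])
            state = rng.choices(children, weights=probs, k=1)[0]
            traj.append(state)
        trajectories.append(traj)
    return trajectories
```

For example, with p_f = {"s0": [("a", 0.5), ("b", 0.5)], "a": [("x", 1.0)]}, every sampled trajectory is either ["s0", "a", "x"] or ["s0", "b"].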

class gfn.losses.base.Parametrization

Bases: abc.ABC

Abstract Base Class for Flow Parametrizations, as defined in Sec. 3 of GFlowNet Foundations. All attributes should be estimators, and each should have either a GFNModule attribute called module or a torch.Tensor attribute called tensor with requires_grad=True.
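The attribute rule above can be expressed as a small check. The sketch below is hypothetical: the function names are illustrative, and duck-typed objects (for example types.SimpleNamespace) stand in for real estimators, modules, and tensors. It tests only for the presence of a module attribute or a tensor attribute with requires_grad=True and does not verify actual types.

```python
from typing import Any


def is_valid_estimator(estimator: Any) -> bool:
    """Check the documented rule: expose `module`, or a `tensor` with requires_grad=True."""
    if hasattr(estimator, "module"):
        return True
    tensor = getattr(estimator, "tensor", None)
    return tensor is not None and getattr(tensor, "requires_grad", False) is True


def check_parametrization(parametrization: Any) -> None:
    # Every instance attribute of the parametrization is expected to be an estimator.
    for name, value in vars(parametrization).items():
        if not is_valid_estimator(value):
            raise TypeError(f"attribute {name!r} is not a trainable estimator")
```

A check like this catches, for instance, a log-partition tensor that was created without requires_grad=True and would therefore never be updated by the optimizer.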

property parameters: dict

Return a dictionary of all parameters of the parametrization. The dictionary may contain duplicates (e.g. when two neural networks share parameters), in which case the optimizer should take set(self.parameters.values()) as input.

Return type

dict
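The deduplication advice above works because torch.Tensor, like a plain Python object, hashes by identity, so two dictionary keys pointing to the same shared parameter collapse to one entry. The sketch below uses a hypothetical Param class in place of real tensors. It also shows an order-preserving alternative to set(): list(dict.fromkeys(...)) keeps first-seen order, which makes the parameter order, and any optimizer state tied to it, stable across runs.

```python
from typing import Dict, List


class Param:
    """Hypothetical stand-in for a trainable tensor; hashes by identity like torch.Tensor."""

    def __init__(self, name: str):
        self.name = name


def unique_parameters(parameters: Dict[str, Param]) -> List[Param]:
    """Drop duplicates (shared parameters) while keeping first-seen order."""
    # dict.fromkeys keeps insertion order, unlike set(), so the result is deterministic.
    return list(dict.fromkeys(parameters.values()))
```

For example, if logit_PF and logit_PB share a trunk, both "logit_PF.trunk.weight" and "logit_PB.trunk.weight" map to the same object. That object then appears only once in the result, so the optimizer does not update it twice per step.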

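P_T(env, n_samples, **kwargs)
Parameters
  • env – Environment in which the trajectories are sampled.

  • n_samples (int) – Number of trajectories to sample.

  • **kwargs – Additional keyword arguments.

Return type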

gfn.distributions.TrajectoryBasedTerminatingStateDistribution
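The return type suggests that $P_T$ is obtained from sampled trajectories. One simple way to picture such a trajectory-based terminating-state distribution is to estimate $P_T(x)$ as the fraction of sampled trajectories that end in $x$. The sketch below assumes trajectories are given as lists of hashable states. It is an illustration of the idea, not the library's implementation.

```python
from collections import Counter
from typing import Dict, Hashable, Sequence


def empirical_terminating_distribution(
    trajectories: Sequence[Sequence[Hashable]],
) -> Dict[Hashable, float]:
    """Estimate P_T(x) as the fraction of trajectories whose last state is x."""
    if not trajectories:
        raise ValueError("need at least one trajectory")
    counts = Counter(traj[-1] for traj in trajectories)
    n = len(trajectories)
    return {state: count / n for state, count in counts.items()}
```

With more samples (a larger n_samples), this empirical estimate concentrates around the true terminating-state distribution induced by the sampling policy.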

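abstract Pi(env, n_samples, **kwargs)
Parameters
  • env – Environment in which the trajectories are sampled.

  • n_samples (int) – Number of trajectories to sample.

  • **kwargs – Additional keyword arguments.

Return type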

gfn.distributions.TrajectoryDistribution

load_state_dict(path)
Parameters

path (str) –

save_state_dict(path)
Parameters

path (str) –

gfn.losses.base.ScoresTensor

Type alias for the score tensors returned by get_trajectories_scores.

class gfn.losses.base.StateDecomposableLoss(parametrization)

Bases: Loss, abc.ABC

Abstract Base Class for GFN losses that decompose over states

Parameters

parametrization (Parametrization) –

abstract __call__(states_tuple)

Unlike the GFlowNet Foundations paper, we allow more flexibility by passing a tuple of states: the first element contains the internal (i.e. non-terminal) states of the trajectories, and the second contains their terminal states. A loss that does not handle the two differently should concatenate them.

Parameters

states_tuple (Tuple[gfn.containers.states.States, gfn.containers.states.States]) –

Return type

torchtyping.TensorType[0, float]
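The concatenation advice above can be sketched as follows. In this hypothetical example, the State alias and the per_state_term callable stand in for real States objects and estimator-based terms. The loss unpacks states_tuple, concatenates both groups because it treats them identically, and averages a per-state term. A loss that needs different treatment, such as one that adds a reward term only at terminal states, would instead process the two elements separately.

```python
from typing import Callable, List, Sequence, Tuple

State = Tuple[int, ...]  # hypothetical state representation


def state_decomposable_loss(
    states_tuple: Tuple[Sequence[State], Sequence[State]],
    per_state_term: Callable[[State], float],
) -> float:
    """Average a per-state term when internal and terminal states need no special handling."""
    internal_states, terminal_states = states_tuple
    # Both groups are treated the same way here, so they are simply concatenated.
    all_states: List[State] = list(internal_states) + list(terminal_states)
    if not all_states:
        raise ValueError("states_tuple contains no states")
    return sum(per_state_term(s) for s in all_states) / len(all_states)
```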

class gfn.losses.base.TrajectoryDecomposableLoss(parametrization)

Bases: Loss, abc.ABC

Abstract Base Class for GFN losses that decompose over trajectories

Parameters

parametrization (Parametrization) –

abstract __call__(trajectories)
Parameters

trajectories (gfn.containers.trajectories.Trajectories) –

Return type

torchtyping.TensorType[0, float]
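A common trajectory-decomposable objective in the GFlowNet literature is trajectory balance. For each complete trajectory ending in x, it penalizes the squared residual log Z + Σ log P_F - log R(x) - Σ log P_B. The minimal sketch below takes per-step log-probabilities as plain lists and averages the squared residual over the batch. The function name and signature are hypothetical and do not describe any specific subclass in this module.

```python
from typing import Sequence


def trajectory_balance_loss(
    log_z: float,
    log_pfs: Sequence[Sequence[float]],
    log_pbs: Sequence[Sequence[float]],
    log_rewards: Sequence[float],
) -> float:
    """Mean squared trajectory-balance residual over a batch of trajectories."""
    if not (len(log_pfs) == len(log_pbs) == len(log_rewards)) or not log_rewards:
        raise ValueError("inputs must describe the same non-empty batch")
    total = 0.0
    for pf, pb, log_r in zip(log_pfs, log_pbs, log_rewards):
        # Trajectory balance: Z * prod P_F = R(x) * prod P_B, written in log space.
        residual = log_z + sum(pf) - log_r - sum(pb)
        total += residual ** 2
    return total / len(log_rewards)
```

Unlike an edge-decomposable loss, each residual here involves the whole trajectory, so the batch must consist of complete trajectories.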

get_pfs_and_pbs(trajectories, fill_value=0.0, temperature=1.0, epsilon=0.0, no_pf=False)

Evaluate log_pf and log_pb for each action in each trajectory in the batch. This is useful when the policy used to sample the trajectories is different from the one used to evaluate the loss.

Parameters
  • trajectories (Trajectories) – Trajectories to evaluate.

  • fill_value (float, optional) – Value to use for invalid states (i.e. the sink state s_f appended to shorter trajectories). Defaults to 0.0.

  • temperature (float, optional) – Softmax temperature the actions_sampler uses when evaluating each action. Defaults to 1.0.

  • epsilon (float, optional) – Epsilon the actions_sampler uses when evaluating each action. Defaults to 0.0.

  • no_pf (bool, optional) – If True, skip evaluating log_pf and return None in its place. Defaults to False.

Raises

ValueError – If trajectories is a batch of backward trajectories.

Returns

A tuple of float tensors of shape (max_length, n_trajectories) containing the log_pf and log_pb for each action in each trajectory. The first element is None when no_pf is True.

Return type

Tuple[LogPTrajectoriesTensor | None, LogPTrajectoriesTensor]
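The documented output layout (time-major grids of shape (max_length, n_trajectories), with fill_value at padded steps and None in place of log_pf when no_pf=True) can be sketched with plain lists. The helper names and the list-of-lists layout are hypothetical. The real method evaluates the estimators on a Trajectories batch and returns tensors, and this sketch does not model temperature or epsilon. With the default fill_value of 0.0, padded steps add nothing when log-probabilities are summed over time.

```python
from typing import List, Optional, Sequence, Tuple

Matrix = List[List[float]]  # rows = time steps, columns = trajectories


def pad_time_major(per_traj: Sequence[Sequence[float]], fill_value: float = 0.0) -> Matrix:
    """Stack variable-length per-trajectory values into a (max_length, n_trajectories) grid."""
    max_length = max((len(t) for t in per_traj), default=0)
    return [
        [t[step] if step < len(t) else fill_value for t in per_traj]
        for step in range(max_length)
    ]


def pfs_and_pbs(
    log_pf_per_traj: Sequence[Sequence[float]],
    log_pb_per_traj: Sequence[Sequence[float]],
    fill_value: float = 0.0,
    no_pf: bool = False,
) -> Tuple[Optional[Matrix], Matrix]:
    # Mirrors the documented return convention: the first element is None when no_pf=True.
    log_pfs = None if no_pf else pad_time_major(log_pf_per_traj, fill_value)
    return log_pfs, pad_time_major(log_pb_per_traj, fill_value)
```

For two trajectories with per-step log_pf values [-1.0, -2.0] and [-3.0], the padded log_pf grid is [[-1.0, -3.0], [-2.0, 0.0]]: column j holds trajectory j, and the missing second step of the shorter trajectory is filled with 0.0.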

get_trajectories_scores(trajectories)
Parameters

trajectories (gfn.containers.trajectories.Trajectories) –

Return type

Tuple[ScoresTensor, ScoresTensor, ScoresTensor]