gfn.gflownet.base

Module Contents

Classes

GFlowNet

Abstract Base Class for GFlowNets.

PFBasedGFlowNet

Base class for GFlowNets that explicitly use \(P_F\).

TrajectoryBasedGFlowNet

Base class for GFlowNets that explicitly use \(P_F\) and are trained on complete trajectories.

Attributes

TrainingSampleType

Type variable for the type of training samples a GFlowNet consumes.

class gfn.gflownet.base.GFlowNet

Bases: abc.ABC, torch.nn.Module, Generic[TrainingSampleType]

Abstract Base Class for GFlowNets.

A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266).

abstract loss(env, training_objects)

Computes the loss given the training objects.

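Parameters
  • env (gfn.env.Env) – the environment the training objects come from.

  • training_objects – the training samples, e.g. as returned by to_training_samples.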

sample_terminating_states(env, n_samples)

Rolls out the GFlowNet's forward policy and returns the terminating states of the sampled trajectories.

Parameters
  • env (gfn.env.Env) – the environment to sample terminating states from.

  • n_samples (int) – number of terminating states to be sampled.

Returns

sampled terminating states object.

Return type

States
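One natural way to build sample_terminating_states on top of sample_trajectories is to sample complete trajectories and keep the terminating state of each. The pure-Python sketch below assumes exactly that; ToyTrajectories, ToyGFlowNet, and the fixed trajectories are hypothetical stand-ins, not torchgfn's Trajectories or States containers.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyTrajectories:
    # Hypothetical stand-in for gfn.containers.Trajectories: one list of
    # visited states per trajectory (the sink state s_f is omitted).
    states: List[List[int]]

    @property
    def last_states(self) -> List[int]:
        # With s_f omitted, each terminating state is the last visited state.
        return [traj[-1] for traj in self.states]


class ToyGFlowNet:
    # Hypothetical GFlowNet whose "policy" replays a fixed set of trajectories.
    def __init__(self, canned: List[List[int]]):
        self.canned = canned

    def sample_trajectories(self, env, n_samples: int) -> ToyTrajectories:
        # Cycle through the canned trajectories to produce n_samples of them.
        picked = [self.canned[i % len(self.canned)] for i in range(n_samples)]
        return ToyTrajectories(states=picked)

    def sample_terminating_states(self, env, n_samples: int) -> List[int]:
        # Roll out the policy, then keep only the final state of each trajectory.
        return self.sample_trajectories(env, n_samples).last_states


toy = ToyGFlowNet(canned=[[0, 1, 3], [0, 2], [0, 1, 4, 5]])
print(toy.sample_terminating_states(env=None, n_samples=4))  # [3, 2, 5, 3]
```

Delegating to sample_trajectories means a subclass only has to implement trajectory sampling to get terminating-state sampling for free.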

abstract sample_trajectories(env, n_samples)

Samples n_samples complete trajectories.

Parameters
  • env (gfn.env.Env) – the environment to sample trajectories from.

  • n_samples (int) – number of trajectories to be sampled.

Returns

sampled trajectories object.

Return type

Trajectories

abstract to_training_samples(trajectories)

Converts trajectories to the training samples consumed by loss. The type of the samples (TrainingSampleType) depends on the GFlowNet.

Parameters

trajectories (gfn.containers.Trajectories) – the trajectories to convert.

Return type

TrainingSampleType
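The three abstract methods fit together as sample, convert, then score. The pure-Python sketch below mirrors that interface with abc and a TypeVar, assuming this typical flow; ToyGFlowNetBase, LengthGFlowNet, and training_step are hypothetical, the real base class also subclasses torch.nn.Module, and the toy "loss" is not a GFlowNet objective.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar

# Stand-in for gfn.gflownet.base.TrainingSampleType.
TrainingSampleType = TypeVar("TrainingSampleType")


class ToyGFlowNetBase(ABC, Generic[TrainingSampleType]):
    # Hypothetical pure-Python mirror of the abstract methods documented above;
    # the real class additionally subclasses torch.nn.Module.

    @abstractmethod
    def sample_trajectories(self, env, n_samples: int) -> List[List[int]]:
        ...

    @abstractmethod
    def to_training_samples(self, trajectories: List[List[int]]) -> TrainingSampleType:
        ...

    @abstractmethod
    def loss(self, env, training_objects: TrainingSampleType) -> float:
        ...


class LengthGFlowNet(ToyGFlowNetBase[List[int]]):
    # Hypothetical subclass: training samples are trajectory lengths and the
    # "loss" is their mean, chosen only to exercise the interface.
    def sample_trajectories(self, env, n_samples):
        return [[0] * (i + 2) for i in range(n_samples)]

    def to_training_samples(self, trajectories):
        return [len(traj) for traj in trajectories]

    def loss(self, env, training_objects):
        return sum(training_objects) / len(training_objects)


def training_step(gflownet: ToyGFlowNetBase, env, n_samples: int) -> float:
    # Sample trajectories, convert them to training samples, compute the loss.
    trajectories = gflownet.sample_trajectories(env, n_samples)
    samples = gflownet.to_training_samples(trajectories)
    return gflownet.loss(env, samples)


print(training_step(LengthGFlowNet(), env=None, n_samples=3))  # 3.0
```

Parametrizing the base class by TrainingSampleType lets each subclass declare what to_training_samples returns while keeping training_step generic.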

class gfn.gflownet.base.PFBasedGFlowNet(pf, pb, on_policy=False)

Bases: GFlowNet

Base class for GFlowNets that explicitly use a forward policy \(P_F\).

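Parameters
  • pf (GFNModule) – the forward policy estimator \(P_F\).

  • pb (GFNModule) – the backward policy estimator \(P_B\).

  • on_policy (bool) – whether training is on-policy, i.e. the trajectories were sampled with the current pf, so that the log-probabilities stored in them can be used directly instead of being re-evaluated.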

sample_trajectories(env, n_samples)

Samples n_samples complete trajectories using the forward policy pf.

Parameters
  • env (gfn.env.Env) – the environment to sample trajectories from.

  • n_samples (int) – number of trajectories to be sampled.

Returns

sampled trajectories object.

Return type

Trajectories
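Sampling complete trajectories with a forward policy amounts to starting from \(s_0\), repeatedly drawing an action from \(P_F(\cdot \mid s)\) until a stop action, and recording \(\log P_F\) for each transition. The pure-Python sketch below illustrates this for a single trajectory at a time; the line environment, its step/stop actions, and the probabilities are hypothetical, and a real implementation would operate on batched tensors instead.

```python
import math
import random
from typing import List, Tuple

MAX_STATE = 3  # hypothetical toy environment: states 0..MAX_STATE on a line


def forward_policy(state: int) -> List[Tuple[str, float]]:
    # Hypothetical P_F(. | s): step right or stop; stopping is forced at MAX_STATE.
    if state == MAX_STATE:
        return [("stop", 1.0)]
    return [("step", 0.6), ("stop", 0.4)]


def sample_trajectory(rng: random.Random) -> Tuple[List[int], List[float]]:
    # Roll out P_F from s_0 = 0 until "stop", recording log P_F(a | s) per step.
    state, states, log_pfs = 0, [0], []
    while True:
        actions, probs = zip(*forward_policy(state))
        action = rng.choices(actions, weights=probs)[0]
        log_pfs.append(math.log(probs[actions.index(action)]))
        if action == "stop":
            return states, log_pfs
        state += 1
        states.append(state)


def sample_trajectories(n_samples: int, seed: int = 0) -> List[Tuple[List[int], List[float]]]:
    # Independent rollouts from a seeded generator, for reproducibility.
    rng = random.Random(seed)
    return [sample_trajectory(rng) for _ in range(n_samples)]


for states, log_pfs in sample_trajectories(3):
    print(states, round(sum(log_pfs), 3))
```

Recording \(\log P_F\) at sampling time is what allows the stored log-probabilities to be reused when training on-policy, as discussed under get_pfs_and_pbs.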

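gfn.gflownet.base.TrainingSampleType

Type variable for the type of training samples returned by to_training_samples and consumed by loss.
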
class gfn.gflownet.base.TrajectoryBasedGFlowNet(pf, pb, on_policy=False)

Bases: PFBasedGFlowNet[gfn.containers.Trajectories]

Base class for GFlowNets that explicitly use \(P_F\) and are trained on complete trajectories.

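Parameters
  • pf (GFNModule) – the forward policy estimator \(P_F\).

  • pb (GFNModule) – the backward policy estimator \(P_B\).

  • on_policy (bool) – whether training is on-policy, i.e. the trajectories were sampled with the current pf, so that the log-probabilities stored in them can be used directly instead of being re-evaluated (see get_pfs_and_pbs).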

get_pfs_and_pbs(trajectories, fill_value=0.0)

Evaluates log-probabilities for each transition in each trajectory in the batch.

More specifically, it evaluates \(\log P_F(s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in each trajectory in the batch.

This is useful when the policy used to sample the trajectories differs from the one used to evaluate the loss. Otherwise, the log-probabilities stored in the trajectories can be used directly.

Parameters
  • trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.

  • fill_value (float) – Value to use for invalid states, i.e. the sink state \(s_f\) appended to pad shorter trajectories.

Return type

Tuple[torchtyping.TensorType[max_length, n_trajectories, torch.float], torchtyping.TensorType[max_length, n_trajectories, torch.float]]

Returns

A tuple of float tensors of shape (max_length, n_trajectories) containing the log_pf and log_pb for each action in each trajectory. The first one can be None.

Raises
  • ValueError – if the trajectories are backward.

  • AssertionError – when actions and states dimensions mismatch.
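In the off-policy case, each transition is re-scored with the current policy and the results are laid out as a (max_length, n_trajectories) matrix, with fill_value at the padded positions. The pure-Python sketch below mimics that layout with nested lists instead of tensors; the (state, action) transition format, current_log_pf, and its probabilities are hypothetical, and only \(\log P_F\) is shown (\(\log P_B\) would be filled the same way using the backward policy).

```python
import math
from typing import Callable, List, Tuple

Transition = Tuple[int, str]  # hypothetical (state, action) pair


def get_log_pfs(
    trajectories: List[List[Transition]],
    log_pf: Callable[[int, str], float],
    fill_value: float = 0.0,
) -> List[List[float]]:
    # Re-evaluate log P_F(a | s) for every transition with the current policy and
    # lay the results out as a (max_length, n_trajectories) matrix. Positions past
    # the end of a shorter trajectory (its s_f padding) receive fill_value.
    max_length = max(len(traj) for traj in trajectories)
    return [
        [log_pf(*traj[t]) if t < len(traj) else fill_value for traj in trajectories]
        for t in range(max_length)
    ]


def current_log_pf(state: int, action: str) -> float:
    # Hypothetical current forward policy: "step" w.p. 0.6, "stop" w.p. 0.4.
    return math.log(0.6 if action == "step" else 0.4)


trajs = [[(0, "step"), (1, "stop")], [(0, "stop")]]
matrix = get_log_pfs(trajs, current_log_pf)
print([[round(x, 3) for x in row] for row in matrix])  # [[-0.511, -0.916], [-0.916, 0.0]]
```

With the default fill_value of 0.0, padded entries contribute nothing when the matrix is summed over the time dimension.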


get_trajectories_scores(trajectories)

Given a batch of trajectories, computes the forward and backward policy scores of each trajectory.

Parameters

trajectories (gfn.containers.Trajectories) – the trajectories to score.

Return type

Tuple[torchtyping.TensorType[n_trajectories, torch.float], torchtyping.TensorType[n_trajectories, torch.float], torchtyping.TensorType[n_trajectories, torch.float]]
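Assuming the per-trajectory scores come from summing the padded (max_length, n_trajectories) matrices returned by get_pfs_and_pbs over the time dimension, the computation can be sketched as below. The third score is assumed here to be total \(\log P_F\) minus total \(\log P_B\) minus the trajectory's log reward; that reading of the three-element return type, the log_rewards input, and the nested lists standing in for tensors are all assumptions of this sketch.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # (max_length, n_trajectories), padded with 0.0


def get_trajectories_scores(
    log_pf_matrix: Matrix, log_pb_matrix: Matrix, log_rewards: List[float]
) -> Tuple[List[float], List[float], List[float]]:
    n = len(log_rewards)
    # Sum over the time dimension to get per-trajectory totals.
    total_log_pf = [sum(row[j] for row in log_pf_matrix) for j in range(n)]
    total_log_pb = [sum(row[j] for row in log_pb_matrix) for j in range(n)]
    # Assumed third score: total log P_F - total log P_B - log reward.
    scores = [pf - pb - lr for pf, pb, lr in zip(total_log_pf, total_log_pb, log_rewards)]
    return total_log_pf, total_log_pb, scores


log_pf = [[-0.5, -1.0], [-0.5, 0.0]]  # second trajectory has one transition
log_pb = [[0.0, 0.0], [-0.7, 0.0]]
pf_tot, pb_tot, scores = get_trajectories_scores(log_pf, log_pb, log_rewards=[-2.0, -1.0])
print(pf_tot, pb_tot)  # [-1.0, -1.0] [-0.7, 0.0]
```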

to_training_samples(trajectories)

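Converts trajectories to training samples. For trajectory-based GFlowNets, the training samples are themselves a gfn.containers.Trajectories object.

Parameters

trajectories (gfn.containers.Trajectories) – the trajectories to convert.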

Return type

gfn.containers.Trajectories