gfn.gflownet.base
Module Contents
Classes
GFlowNet | Abstract Base Class for GFlowNets.
PFBasedGFlowNet | Base class for GFlowNets that explicitly use \(P_F\).
TrajectoryBasedGFlowNet | Base class for GFlowNets that explicitly use \(P_F\).
Attributes
TrainingSampleType
- class gfn.gflownet.base.GFlowNet
Bases: abc.ABC, torch.nn.Module, Generic[TrainingSampleType]
Abstract Base Class for GFlowNets.
A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266).
- abstract loss(env, training_objects)
Computes the loss given the training objects.
- Parameters
env (gfn.env.Env) –
- sample_terminating_states(env, n_samples)
Rolls out the parametrization’s policy and returns the terminating states.
- Parameters
env (gfn.env.Env) – the environment to sample terminating states from.
n_samples (int) – number of terminating states to be sampled.
- Returns
sampled terminating states object.
- Return type
- abstract sample_trajectories(env, n_samples)
Sample a specific number of complete trajectories.
- Parameters
env (gfn.env.Env) – the environment to sample trajectories from.
n_samples (int) – number of trajectories to be sampled.
- Returns
sampled trajectories object.
- Return type
gfn.containers.Trajectories
- abstract to_training_samples(trajectories)
Converts trajectories to training samples. The type depends on the GFlowNet.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
TrainingSampleType
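The relationship between the abstract methods above can be illustrated with a small standalone sketch. It is not the library's code: `ToyGFlowNet`, `CountingGFlowNet`, and the integer `env` are hypothetical stand-ins, `torch.nn.Module` is left out so the example runs with the standard library only, and it assumes that `sample_terminating_states` simply takes the last state of each sampled trajectory.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar

# Hypothetical stand-in for gfn.gflownet.base.TrainingSampleType.
TrainingSampleType = TypeVar("TrainingSampleType")

# A toy trajectory is just the list of visited states (integers here).
Trajectory = List[int]


class ToyGFlowNet(ABC, Generic[TrainingSampleType]):
    """Mirrors the method names of GFlowNet, without torch.nn.Module."""

    @abstractmethod
    def sample_trajectories(self, env, n_samples: int) -> List[Trajectory]:
        ...

    @abstractmethod
    def to_training_samples(self, trajectories: List[Trajectory]) -> TrainingSampleType:
        ...

    @abstractmethod
    def loss(self, env, training_objects: TrainingSampleType) -> float:
        ...

    def sample_terminating_states(self, env, n_samples: int) -> List[int]:
        # Assumption: a terminating state is the last state of a sampled trajectory.
        return [traj[-1] for traj in self.sample_trajectories(env, n_samples)]


class CountingGFlowNet(ToyGFlowNet[List[Trajectory]]):
    """Deterministic toy subclass whose training samples are the trajectories."""

    def sample_trajectories(self, env, n_samples):
        # `env` is a hypothetical integer: the number of steps in every trajectory.
        return [list(range(env + 1)) for _ in range(n_samples)]

    def to_training_samples(self, trajectories):
        return trajectories

    def loss(self, env, training_objects):
        # Placeholder objective: mean number of transitions per trajectory.
        return sum(len(t) - 1 for t in training_objects) / len(training_objects)
```

A subclass fixes the generic parameter (here `List[Trajectory]`), which is what ties the output of `to_training_samples` to the input of `loss`.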
- class gfn.gflownet.base.PFBasedGFlowNet(pf, pb, on_policy=False)
Bases: GFlowNet
Base class for GFlowNets that explicitly use \(P_F\).
- Parameters
pf (gfn.modules.GFNModule) –
pb (gfn.modules.GFNModule) –
on_policy (bool) –
- pf
GFNModule
- pb
GFNModule
- sample_trajectories(env, n_samples)
Sample a specific number of complete trajectories.
- Parameters
env (gfn.env.Env) – the environment to sample trajectories from.
n_samples (int) – number of trajectories to be sampled.
- Returns
sampled trajectories object.
- Return type
gfn.containers.Trajectories
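As a rough illustration of what sampling complete trajectories with a forward policy involves, the sketch below rolls out a toy policy until it exits. Everything in it is an assumption made for the example: the integer-valued environment, the `SINK` marker for \(s_f\), the `pf_exit_prob` callable, and the `max_length` cap are hypothetical and do not reflect the library's sampler or its `env` interface.

```python
import random
from typing import Callable, List

SINK = -1  # hypothetical marker for the sink state s_f


def sample_trajectories(
    pf_exit_prob: Callable[[int], float],
    n_samples: int,
    max_length: int,
    rng: random.Random,
) -> List[List[int]]:
    """Roll out a toy forward policy from state 0 until it chooses to exit."""
    trajectories = []
    for _ in range(n_samples):
        state, traj = 0, [0]
        # Keep stepping until the policy exits or the length cap forces an exit.
        while len(traj) < max_length and rng.random() >= pf_exit_prob(state):
            state += 1
            traj.append(state)
        traj.append(SINK)  # every complete trajectory ends in the sink state
        trajectories.append(traj)
    return trajectories


trajs = sample_trajectories(lambda s: 0.5, n_samples=5, max_length=4, rng=random.Random(0))
```

Each returned list is a complete trajectory: it starts at the initial state and ends in the sink state, and the state just before `SINK` is the terminating state.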
- gfn.gflownet.base.TrainingSampleType
- class gfn.gflownet.base.TrajectoryBasedGFlowNet(pf, pb, on_policy=False)
Bases: PFBasedGFlowNet[gfn.containers.Trajectories]
Base class for GFlowNets that explicitly use \(P_F\).
- Parameters
pf (gfn.modules.GFNModule) –
pb (gfn.modules.GFNModule) –
on_policy (bool) –
- pf
GFNModule
- pb
GFNModule
- get_pfs_and_pbs(trajectories, fill_value=0.0)
Evaluates logprobs for each transition in each trajectory in the batch.
More specifically it evaluates \(\log P_F (s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in each trajectory in the batch.
This is useful when the policy used to sample the trajectories differs from the one used to evaluate the loss. Otherwise, the logprobs stored in the trajectories can be used directly.
- Parameters
trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.
fill_value (float) – Value to use for invalid states (i.e. the \(s_f\) entries that pad shorter trajectories).
- Return type
Tuple[torchtyping.TensorType[max_length, n_trajectories, torch.float], torchtyping.TensorType[max_length, n_trajectories, torch.float]]
- Returns
A tuple of float tensors of shape (max_length, n_trajectories) containing the log_pf and log_pb for each action in each trajectory. The first one can be None.
- Raises
ValueError – if the trajectories are backward.
AssertionError – when actions and states dimensions mismatch.
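The role of `fill_value` in the `(max_length, n_trajectories)` layout can be shown with plain nested lists. This is a minimal sketch, not the library's implementation: `pad_log_probs` is a hypothetical helper, the probabilities are made up, and lists stand in for tensors.

```python
import math
from typing import List


def pad_log_probs(per_trajectory: List[List[float]], fill_value: float = 0.0) -> List[List[float]]:
    """Arrange ragged per-step logprobs into a (max_length, n_trajectories) grid."""
    max_length = max(len(lp) for lp in per_trajectory)
    return [
        [lp[t] if t < len(lp) else fill_value for lp in per_trajectory]
        for t in range(max_length)
    ]


log_pf_steps = [
    [math.log(0.5), math.log(0.5), math.log(0.25)],  # trajectory with 3 transitions
    [math.log(0.5)],                                 # trajectory with 1 transition
]
log_pf = pad_log_probs(log_pf_steps)

# Summing over the time axis gives each trajectory's total logprob, because
# the padded entries contribute exactly fill_value = 0.0.
totals = [sum(row[j] for row in log_pf) for j in range(len(log_pf_steps))]
```

With the default `fill_value=0.0`, padded positions are neutral under summation, so shorter trajectories are not penalised for the padding.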
- get_trajectories_scores(trajectories)
Given a batch of trajectories, calculate forward & backward policy scores.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
Tuple[torchtyping.TensorType[n_trajectories, torch.float], torchtyping.TensorType[n_trajectories, torch.float], torchtyping.TensorType[n_trajectories, torch.float]]
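The text above does not say what the three returned tensors contain. One plausible reading, stated here purely as an assumption, is the per-trajectory total \(\log P_F\), the per-trajectory total \(\log P_B\), and their difference. The sketch below computes those three quantities from padded `(max_length, n_trajectories)` grids, using nested lists in place of tensors and a hypothetical `trajectory_scores` helper.

```python
from typing import List, Tuple


def trajectory_scores(
    log_pf: List[List[float]], log_pb: List[List[float]]
) -> Tuple[List[float], List[float], List[float]]:
    """Sum per-step logprobs over time, then take the forward/backward difference."""
    n_trajectories = len(log_pf[0])
    total_log_pf = [sum(row[j] for row in log_pf) for j in range(n_trajectories)]
    total_log_pb = [sum(row[j] for row in log_pb) for j in range(n_trajectories)]
    # Assumed third output: log P_F(trajectory) - log P_B(trajectory).
    scores = [f - b for f, b in zip(total_log_pf, total_log_pb)]
    return total_log_pf, total_log_pb, scores


# Two trajectories, max_length = 2; the second is padded with 0.0 at step 2.
log_pf = [[-1.0, -2.0], [-0.5, 0.0]]
log_pb = [[-0.25, -1.0], [-0.25, 0.0]]
total_pf, total_pb, scores = trajectory_scores(log_pf, log_pb)
```

Each of the three outputs has one entry per trajectory, matching the `n_trajectories` shape in the return type.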
- to_training_samples(trajectories)
Converts trajectories to training samples. The type depends on the GFlowNet.
- Parameters
trajectories (gfn.containers.Trajectories) –
- Return type
gfn.containers.Trajectories