gfn.containers
Submodules
Package Contents
Classes
ReplayBuffer | A replay buffer of trajectories or transitions.
Trajectories | Container for complete trajectories (starting in \(s_0\) and ending in \(s_f\)).
Transitions | Container for the transitions.
- class gfn.containers.ReplayBuffer(env, objects_type, capacity=1000)
A replay buffer of trajectories or transitions.
- Parameters
env (gfn.env.Env) –
objects_type (Literal[transitions, trajectories, states]) –
capacity (int) –
- env
the Environment instance.
- loss_fn
the Loss instance
- capacity
the size of the buffer.
- training_objects
the buffer of objects used for training.
- terminating_states
a States class representation of \(s_f\).
- objects_type
the type of buffer (transitions, trajectories, or states).
- __len__()
- __repr__()
Return repr(self).
- add(training_objects)
Adds a training object to the buffer.
- Parameters
training_objects (Transitions | Trajectories | tuple[States]) –
- load(directory)
Loads the buffer from disk.
- Parameters
directory (str) –
- sample(n_trajectories)
Samples n_trajectories training objects from the buffer.
- Parameters
n_trajectories (int) –
- Return type
Transitions | Trajectories | tuple[States]
- save(directory)
Saves the buffer to disk.
- Parameters
directory (str) –
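The capacity-bounded behaviour described above can be pictured with a small standard-library sketch. `ToyReplayBuffer` is a hypothetical stand-in, not the gfn class: it stores plain Python objects and assumes that the oldest items are discarded once `capacity` is exceeded and that sampling is uniform without replacement. Neither assumption is stated on this page.

```python
import random


class ToyReplayBuffer:
    """Hypothetical capacity-bounded buffer (illustration only, not gfn's class)."""

    def __init__(self, capacity=1000):
        self.capacity = capacity
        self.training_objects = []

    def __len__(self):
        return len(self.training_objects)

    def add(self, new_objects):
        # Assumption: once the buffer is full, the oldest objects are dropped first.
        self.training_objects.extend(new_objects)
        self.training_objects = self.training_objects[-self.capacity:]

    def sample(self, n, rng=random):
        # Assumption: uniform sampling without replacement.
        return rng.sample(self.training_objects, n)


buffer = ToyReplayBuffer(capacity=3)
buffer.add(["traj_a", "traj_b"])
buffer.add(["traj_c", "traj_d"])  # exceeds capacity, so "traj_a" is dropped
batch = buffer.sample(2, rng=random.Random(0))
```

After the second `add`, the toy buffer holds three objects and `batch` contains two of them.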
- class gfn.containers.Trajectories(env, states=None, actions=None, when_is_done=None, is_backward=False, log_rewards=None, log_probs=None)
Bases: gfn.containers.base.Container
Container for complete trajectories (starting in \(s_0\) and ending in \(s_f\)).
Trajectories are represented as a States object with a two-dimensional batch shape, and actions as an Actions object with a two-dimensional batch shape. The first dimension indexes the time step and the second indexes the trajectory. Because trajectories may have different lengths, shorter ones are padded with the tensor representation of the terminal state (\(s_f\) or \(s_0\), depending on the direction of the trajectory), and their actions are padded with dummy actions. The when_is_done tensor stores the time step at which each trajectory ends.
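The padded, time-major layout can be illustrated with plain Python lists. This is a toy sketch, not gfn code: states are strings, `S_F` is a hypothetical placeholder for the sink state, and `when_is_done` is assumed to count the transitions of each trajectory, including the final step into \(s_f\).

```python
S_F = "s_f"  # hypothetical placeholder for the sink (padding) state

# Three forward trajectories of different lengths, each ending in s_f.
raw = [
    ["s_0", "a", "b", S_F],
    ["s_0", "c", S_F],
    ["s_0", "d", "e", "f", S_F],
]

# Assumption: when_is_done is the number of transitions in each trajectory.
when_is_done = [len(trajectory) - 1 for trajectory in raw]
max_length = max(when_is_done)

# Pad every trajectory with s_f up to max_length + 1 states.
padded = [t + [S_F] * (max_length + 1 - len(t)) for t in raw]

# Transpose so that states[t][i] is the state of trajectory i at time step t.
states = [list(row) for row in zip(*padded)]
```

Here `states` has `max_length + 1` rows and one column per trajectory. Every row after a trajectory's end holds only padding for that trajectory.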
- Parameters
env (gfn.env.Env) –
states (States | None) –
actions (Actions | None) –
when_is_done (TT['n_trajectories', torch.long] | None) –
is_backward (bool) –
log_rewards (TT['n_trajectories', torch.float] | None) –
log_probs (TT['max_length', 'n_trajectories', torch.float] | None) –
- env
The environment in which the trajectories are defined.
- states
The states of the trajectories.
- actions
The actions of the trajectories.
- when_is_done
The time step at which each trajectory ends.
- is_backward
Whether the trajectories are backward or forward.
- log_rewards
The log_rewards of the trajectories.
- Return type
TT['n_trajectories', torch.float] | None
- log_probs
The log probabilities of the trajectories’ actions.
- property last_states: gfn.states.States
- Return type
gfn.states.States
- property log_rewards: TT['n_trajectories', torch.float] | None
- Return type
TT['n_trajectories', torch.float] | None
- property max_length: int
- Return type
int
- property n_trajectories: int
- Return type
int
- __getitem__(index)
Returns a subset of the n_trajectories trajectories.
- Parameters
index (int | Sequence[int]) –
- Return type
Trajectories
- __len__()
Returns the number of elements in the container.
- Return type
int
- __repr__()
Return repr(self).
- Return type
str
- extend(other)
Extend the trajectories with another set of trajectories.
Extends along all attributes in turn (actions, states, when_is_done, log_probs, log_rewards).
- Parameters
other (Trajectories) – an external set of Trajectories.
- Return type
None
- static extend_log_probs(log_probs, new_max_length)
Extends the log_probs matrix by padding it with zeros along the time dimension until it reaches new_max_length.
- Parameters
log_probs (torchtyping.TensorType[max_length, n_trajectories, torch.float]) –
new_max_length (int) –
- Return type
torchtyping.TensorType[max_max_length, n_trajectories, torch.float]
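The zero-padding step can be sketched with nested lists in place of tensors. `extend_log_probs_sketch` is a hypothetical helper written for this illustration. It assumes a `[max_length][n_trajectories]` layout and returns an unchanged copy when the input is already long enough, which is an assumption rather than documented behaviour.

```python
def extend_log_probs_sketch(log_probs, new_max_length):
    """Pad a [max_length][n_trajectories] nested list with rows of 0.0."""
    max_length = len(log_probs)
    copied = [row[:] for row in log_probs]
    if max_length >= new_max_length:
        # Assumption: nothing to do when the matrix is already long enough.
        return copied
    n_trajectories = len(log_probs[0]) if log_probs else 0
    padding = [[0.0] * n_trajectories for _ in range(new_max_length - max_length)]
    return copied + padding


log_probs = [[-0.5, -1.2], [-0.7, -0.3]]
extended = extend_log_probs_sketch(log_probs, 4)

# Summing over time gives the same per-trajectory total before and after padding.
totals_before = [sum(column) for column in zip(*log_probs)]
totals_after = [sum(column) for column in zip(*extended)]
```

Padding with zeros leaves the summed log-probability of each trajectory unchanged, since the added entries contribute nothing to the sum.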
- to_non_initial_intermediary_and_terminating_states()
Returns all intermediary and terminating States from the trajectories.
This is useful for the flow matching loss, which requires these two kinds of states to be distinguished.
- Returns
A tuple containing all the intermediary states in the trajectories that are not s0, and all the terminating states in the trajectories that are not s0.
- Return type
tuple[gfn.states.States, gfn.states.States]
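One way to read the split described above is shown below on a toy state grid. This is a sketch under assumptions, not the library's implementation: states are strings, intermediary states are taken to be every entry that is neither \(s_0\) nor the sink, and terminating states are taken to be the last non-sink state of each trajectory.

```python
S_0, S_F = "s_0", "s_f"  # hypothetical placeholders for the source and sink states

# states[t][i] is the state of trajectory i at time step t (padded with s_f).
states = [
    [S_0, S_0, S_0],
    ["a", "c", S_F],
    ["b", S_F, S_F],
    [S_F, S_F, S_F],
]
# Assumption: when_is_done counts transitions, including the step into s_f.
when_is_done = [3, 2, 1]

# Every visited state that is neither s_0 nor the sink.
intermediary = [s for row in states for s in row if s not in (S_0, S_F)]

# The last non-sink state of each trajectory, dropping trajectories that stop at s_0.
last_states = [states[n - 1][i] for i, n in enumerate(when_is_done)]
terminating = [s for s in last_states if s != S_0]
```

In this toy reading a state such as `"b"` appears in both groups, because it is visited as an intermediary state and is also where its trajectory terminates.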
- to_states()
Returns a States object containing all states in the trajectories.
- Return type
gfn.states.States
- to_transitions()
Returns a Transitions object from the trajectories.
- Return type
Transitions
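Conceptually, converting trajectories to transitions means pairing each state with its successor and flagging the step that enters the sink. The sketch below uses plain lists instead of gfn objects. The tuple layout `(parent, child, is_done)`, the trajectory-major ordering, and the meaning of `when_is_done` as a transition count are all assumptions made for illustration.

```python
S_F = "s_f"  # hypothetical placeholder for the sink state

# states[t][i] is the state of trajectory i at time step t (padded with s_f).
states = [
    ["s_0", "s_0"],
    ["a", "c"],
    ["b", S_F],
    [S_F, S_F],
]
when_is_done = [3, 2]  # assumed: number of transitions per trajectory

transitions = []
for i, n_steps in enumerate(when_is_done):
    for t in range(n_steps):
        parent, child = states[t][i], states[t + 1][i]
        is_done = child == S_F  # the exit step leads into the sink
        transitions.append((parent, child, is_done))
```

The padded entries beyond each trajectory's end produce no transitions, so the total count equals `sum(when_is_done)`.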
- class gfn.containers.Transitions(env, states=None, actions=None, is_done=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)
Bases: gfn.containers.base.Container
Container for the transitions.
- Parameters
env (gfn.env.Env) –
states (States | None) –
actions (Actions | None) –
is_done (TT['n_transitions', torch.bool] | None) –
next_states (States | None) –
is_backward (bool) –
log_rewards (TT['n_transitions', torch.float] | None) –
log_probs (TT['n_transitions', torch.float] | None) –
- env
The environment in which the transitions are defined.
- is_backward
Whether the transitions are backward transitions (i.e. next_states is the parent of states).
- states
States object with uni-dimensional batch_shape, representing the parents of the transitions.
- actions
Actions chosen at the parent of each transition.
- is_done
Whether the action is the exit action.
- next_states
States object with uni-dimensional batch_shape, representing the children of the transitions.
- log_probs
The log-probabilities of the actions.
- property all_log_rewards: torchtyping.TensorType[n_transitions, 2, torch.float]
Calculate all log rewards for the transitions.
This is applicable to environments where all states are terminating. This function evaluates the rewards for all transitions that do not end in the sink state. This is useful for the Modified Detailed Balance loss.
- Raises
NotImplementedError – when used for backward transitions.
- Return type
torchtyping.TensorType[n_transitions, 2, torch.float]
- property last_states: gfn.states.States
The last states, i.e. the terminating states.
- Return type
gfn.states.States
- property log_rewards: TT['n_transitions', torch.float] | None
- Return type
TT['n_transitions', torch.float] | None
- property n_transitions: int
- Return type
int
- __getitem__(index)
Access particular transitions of the batch.
- Parameters
index (int | Sequence[int]) –
- Return type
Transitions
- __len__()
- Return type
int
- __repr__()
- extend(other)
Extend the Transitions object with another Transitions object.
- Parameters
other (Transitions) –
- Return type
None