gfn.containers

Package Contents

Classes

ReplayBuffer

A replay buffer of trajectories or transitions.

Trajectories

Container for complete trajectories (starting in \(s_0\) and ending in \(s_f\)).

Transitions

Container for the transitions.

class gfn.containers.ReplayBuffer(env, objects_type, capacity=1000)

A replay buffer of trajectories or transitions.

Parameters
  • env (gfn.env.Env) –

  • objects_type (Literal[transitions, trajectories, states]) –

  • capacity (int) –

env

the Environment instance.

loss_fn

the Loss instance.

capacity

the size of the buffer.

training_objects

the buffer of objects used for training.

terminating_states

a States class representation of \(s_f\).

objects_type

the type of buffer (transitions, trajectories, or states).

__len__()

__repr__()

Return repr(self).

add(training_objects)

Adds a training object to the buffer.

Parameters

training_objects (Transitions | Trajectories | tuple[States]) –

load(directory)

Loads the buffer from disk.

Parameters

directory (str) –

sample(n_trajectories)

Samples n_trajectories training objects from the buffer.

Parameters

n_trajectories (int) –

Return type

Transitions | Trajectories | tuple[States]

save(directory)

Saves the buffer to disk.

Parameters

directory (str) –
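
The add/sample cycle documented above can be illustrated with a small, library-independent sketch. It is not the package's implementation: the class name `TinyReplayBuffer`, the drop-oldest eviction policy, and sampling without replacement are all assumptions made for illustration.

```python
import random
from collections import deque


class TinyReplayBuffer:
    """Hypothetical stand-in for a fixed-capacity buffer of training objects."""

    def __init__(self, capacity=1000):
        self.capacity = capacity
        # Assumption: once the buffer is full, the oldest objects are dropped first.
        self.training_objects = deque(maxlen=capacity)

    def __len__(self):
        return len(self.training_objects)

    def add(self, training_objects):
        # Accept a batch (any iterable) of training objects.
        self.training_objects.extend(training_objects)

    def sample(self, n_trajectories, rng=random):
        # Assumption: sample without replacement, capped at the buffer size.
        n = min(n_trajectories, len(self))
        return rng.sample(list(self.training_objects), n)


buffer = TinyReplayBuffer(capacity=3)
buffer.add(["traj_a", "traj_b"])
buffer.add(["traj_c", "traj_d"])  # exceeds the capacity, so "traj_a" is evicted
batch = buffer.sample(2)
```

In this sketch the buffer never holds more than `capacity` objects, so memory use stays bounded however many batches are added.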

class gfn.containers.Trajectories(env, states=None, actions=None, when_is_done=None, is_backward=False, log_rewards=None, log_probs=None)

Bases: gfn.containers.base.Container

Container for complete trajectories (starting in \(s_0\) and ending in \(s_f\)).

Trajectories are stored as a States object with a two-dimensional batch shape, and their actions as an Actions object with a two-dimensional batch shape. The first dimension indexes the time step and the second indexes the trajectory. Because trajectories may differ in length, shorter ones are padded with the tensor representation of the terminal state (\(s_f\) or \(s_0\), depending on the direction of the trajectory), and their actions are padded with dummy actions. The when_is_done tensor records the time step at which each trajectory ends.
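
The padded, time-major layout described above can be sketched with plain Python lists. This is only an illustration for forward trajectories: the helper `pad_trajectories`, the string states, and the convention that a trajectory is done after as many steps as it has non-sink states are assumptions, not the library's code.

```python
def pad_trajectories(trajectories, sink_state):
    """Stack variable-length state sequences into a time-major grid.

    Returns (states, when_is_done), where states[t][i] is the state of
    trajectory i at time step t.
    """
    # Assumption: each trajectory lists its states from s_0 up to its last
    # non-sink state, so it finishes after len(traj) steps.
    when_is_done = [len(traj) for traj in trajectories]
    max_length = max(when_is_done)
    # One extra row so that even the longest trajectory shows its sink state.
    states = [
        [traj[t] if t < len(traj) else sink_state for traj in trajectories]
        for t in range(max_length + 1)
    ]
    return states, when_is_done


states, when_is_done = pad_trajectories(
    [["s0", "a", "b"], ["s0", "c"]], sink_state="sf"
)
```

Here the first index of `states` is the time step and the second is the trajectory index, mirroring the batch shape described in the text.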

Parameters
  • env (gfn.env.Env) –

  • states (States | None) –

  • actions (Actions | None) –

  • when_is_done (TT['n_trajectories', torch.long] | None) –

  • is_backward (bool) –

  • log_rewards (TT['n_trajectories', torch.float] | None) –

  • log_probs (TT['max_length', 'n_trajectories', torch.float] | None) –

env

The environment in which the trajectories are defined.

states

The states of the trajectories.

actions

The actions of the trajectories.

when_is_done

The time step at which each trajectory ends.

is_backward

Whether the trajectories are backward or forward.

log_rewards

The log_rewards of the trajectories.

Return type

TT['n_trajectories', torch.float] | None

log_probs

The log probabilities of the trajectories’ actions.

property last_states: gfn.states.States

Return type

gfn.states.States

property log_rewards: TT['n_trajectories', torch.float] | None

Return type

TT['n_trajectories', torch.float] | None

property max_length: int

Return type

int

property n_trajectories: int

Return type

int

__getitem__(index)

Returns a subset of the n_trajectories trajectories.

Parameters

index (int | Sequence[int]) –

Return type

Trajectories

__len__()

Returns the number of elements in the container.

Return type

int

__repr__()

Return repr(self).

Return type

str

extend(other)

Extend the trajectories with another set of trajectories.

Extends along all attributes in turn (actions, states, when_is_done, log_probs, log_rewards).

Parameters

other (Trajectories) – an external set of Trajectories.

Return type

None

static extend_log_probs(log_probs, new_max_length)

Extend the log_probs matrix along the time dimension by padding it with zeros until it reaches the required length.

Parameters
  • log_probs (torchtyping.TensorType[max_length, n_trajectories, torch.float]) –

  • new_max_length (int) –

Return type

torchtyping.TensorType[max_max_length, n_trajectories, torch.float]
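
The zero-padding step can be sketched on nested lists in place of tensors. The helper name `pad_log_probs` is hypothetical, and returning the input unchanged when it is already long enough is an assumption of this sketch rather than documented behaviour.

```python
def pad_log_probs(log_probs, new_max_length):
    """Pad a time-major [max_length][n_trajectories] grid with rows of 0.0."""
    if len(log_probs) >= new_max_length:
        # Assumption: a grid that is already long enough is returned as is.
        return log_probs
    n_trajectories = len(log_probs[0]) if log_probs else 0
    n_missing = new_max_length - len(log_probs)
    padding = [[0.0] * n_trajectories for _ in range(n_missing)]
    return log_probs + padding


log_probs = [[-0.5, -1.0], [-0.2, -0.3]]
extended = pad_log_probs(log_probs, new_max_length=4)
```

Padding with 0 is a natural choice because a log-probability of 0 corresponds to a probability of 1, so summing each trajectory's log-probabilities over time gives the same total before and after padding.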

to_non_initial_intermediary_and_terminating_states()

Returns all intermediate and terminating States from the trajectories.

This is useful for the flow matching loss, which requires its inputs to be distinguished.

Returns: a tuple containing all the intermediary states in the trajectories that are not \(s_0\), and all the terminating states in the trajectories that are not \(s_0\).

Return type

tuple[gfn.states.States, gfn.states.States]

to_states()

Returns a States object containing all states in the trajectories.

Return type

gfn.states.States

to_transitions()

Returns a Transitions object from the trajectories.

Return type

gfn.containers.transitions.Transitions
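
Conceptually, converting a trajectory into transitions means pairing each state with the action taken there and the state that follows. The sketch below shows this for a single forward trajectory using plain lists. The function name, the tuple layout, and the `"s_f"` sink marker are illustrative assumptions and do not reflect how the library stores its data.

```python
def trajectory_to_transitions(states, actions, sink_state="s_f"):
    """Split one forward trajectory into (state, action, next_state, is_done) tuples."""
    # Assumption: one action per non-sink state, the last one being the exit action.
    assert len(actions) == len(states), "expected one action per non-sink state"
    # The state reached after the final (exit) action is the sink state.
    next_states = states[1:] + [sink_state]
    return [
        (state, action, next_state, next_state == sink_state)
        for state, action, next_state in zip(states, actions, next_states)
    ]


transitions = trajectory_to_transitions(
    states=["s0", "s1", "s2"],
    actions=["right", "up", "exit"],
)
```

Only the last tuple has `is_done` set, matching the idea that `is_done` marks the exit action.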

class gfn.containers.Transitions(env, states=None, actions=None, is_done=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)

Bases: gfn.containers.base.Container

Container for the transitions.

Parameters
  • env (gfn.env.Env) –

  • states (States | None) –

  • actions (Actions | None) –

  • is_done (TT['n_transitions', torch.bool] | None) –

  • next_states (States | None) –

  • is_backward (bool) –

  • log_rewards (TT['n_transitions', torch.float] | None) –

  • log_probs (TT['n_transitions', torch.float] | None) –

env

The environment in which the transitions are defined.

is_backward

Whether the transitions are backward transitions (i.e. next_states is the parent of states).

states

States object with uni-dimensional batch_shape, representing the parents of the transitions.

actions

Actions chosen at the parent of each transition.

is_done

Whether the action is the exit action.

next_states

States object with uni-dimensional batch_shape, representing the children of the transitions.

log_probs

The log-probabilities of the actions.

property all_log_rewards: torchtyping.TensorType[n_transitions, 2, torch.float]

Calculate all log rewards for the transitions.

This is applicable to environments where all states are terminating. This function evaluates the rewards for all transitions that do not end in the sink state. This is useful for the Modified Detailed Balance loss.

Raises

NotImplementedError – when used for backward transitions.

Return type

torchtyping.TensorType[n_transitions, 2, torch.float]
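
The `[n_transitions, 2]` shape can be read as one pair of log rewards per transition, one for the parent state and one for the child. The plain-Python sketch below is written under assumptions: the helper `pairwise_log_rewards`, the toy `log_reward` function, the integer states, and the use of `-inf` as a placeholder for transitions that end in the sink state are all hypothetical.

```python
import math


def pairwise_log_rewards(transitions, log_reward, sink_state="s_f"):
    """Return one (log R(state), log R(next_state)) pair per transition."""
    rows = []
    for state, next_state in transitions:
        if next_state == sink_state:
            # Assumption: transitions into the sink are not evaluated;
            # -inf is only a placeholder chosen for this sketch.
            rows.append((-math.inf, -math.inf))
        else:
            rows.append((log_reward(state), log_reward(next_state)))
    return rows


def log_reward(state):
    # Toy reward for illustration only: R(s) = 1 + s for integer states.
    return math.log(1 + state)


rows = pairwise_log_rewards([(0, 1), (1, "s_f")], log_reward)
```

Both entries of a pair are meaningful only when every state is terminating, which is the setting the property is documented for.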

property last_states: gfn.states.States

Get the last states, i.e. the terminating states.

Return type

gfn.states.States

property log_rewards: TT['n_transitions', torch.float] | None

Return type

TT['n_transitions', torch.float] | None

property n_transitions: int

Return type

int

__getitem__(index)

Access particular transitions of the batch.

Parameters

index (int | Sequence[int]) –

Return type

Transitions

__len__()

Return type

int

__repr__()

extend(other)

Extend the Transitions object with another Transitions object.

Parameters

other (Transitions) –

Return type

None