gfn.containers

Submodules

Package Contents

Classes

States

Base class for states, seen as nodes of the DAG.

Trajectories

Container for a batch of trajectories (sequences of states and the actions taken between them).

Transitions

Container for a batch of transitions (state, action, next state).

class gfn.containers.States(states_tensor, forward_masks=None, backward_masks=None)

Bases: gfn.containers.base.Container, abc.ABC

Base class for states, seen as nodes of the DAG. Each environment needs its own States subclass. A States object is a collection of multiple states (nodes of the DAG). Batching requires a tensor representation of the states: if a single state is represented by a tensor of shape (*state_shape), a batch of states is a States object whose states_tensor attribute has shape (*batch_shape, *state_shape). Other representations are possible (e.g. a state as a string, a NumPy array, or a graph), but they should not be batched.

If the environment’s action space is discrete, each States object also has boolean forward_masks and backward_masks attributes indicating which actions are allowed at each state.
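As an illustration, the sketch below computes forward and backward masks for a hypothetical HyperGrid-style environment, using plain Python lists instead of tensors. The environment, its action layout (one "increment coordinate i" forward action per dimension plus a final exit action, one "decrement coordinate i" backward action per dimension) and the helper make_masks_for_grid are assumptions made for this example, not part of gfn.containers.

```python
from typing import List, Tuple

# Hypothetical HyperGrid-style environment (an assumption, not part of gfn):
# a state is a list of `ndim` integer coordinates in [0, height - 1].
# Forward actions: "increment coordinate i" for each i, then one "exit" action.
# Backward actions: "decrement coordinate i" for each i.


def make_masks_for_grid(
    states: List[List[int]], height: int
) -> Tuple[List[List[bool]], List[List[bool]]]:
    """Return (forward_masks, backward_masks) for a one-dimensional batch of states."""
    forward_masks: List[List[bool]] = []
    backward_masks: List[List[bool]] = []
    for state in states:
        # Incrementing coordinate i is allowed only if it stays on the grid.
        fwd_row = [coord < height - 1 for coord in state]
        # In this toy environment, the exit action is allowed from every state.
        fwd_row.append(True)
        # Decrementing coordinate i is allowed only if the coordinate is positive.
        bwd_row = [coord > 0 for coord in state]
        forward_masks.append(fwd_row)
        backward_masks.append(bwd_row)
    return forward_masks, backward_masks


fwd, bwd = make_masks_for_grid([[0, 0], [2, 1]], height=3)
print(fwd)  # [[True, True, True], [False, True, True]]
print(bwd)  # [[False, False], [True, True]]
```

Both masks have one row per state; in this layout the forward mask carries one extra column for the exit action.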

A batch_shape attribute is also required to keep track of the batch dimensions. A single trajectory can be represented by a States object with batch_shape = (n_states,), and multiple trajectories by a States object with batch_shape = (n_states, n_trajectories).

Because trajectories can have different lengths, batching them requires padding every trajectory shorter than the longest one with dummy states. The dummy state is the environment’s s_f attribute (e.g. [-1, …, -1] or [-inf, …, -inf]); it is never processed and serves only to pad the batch of states.
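The padding scheme can be sketched with plain nested lists standing in for states_tensor. The helper pad_trajectories is hypothetical, and s_f = [-1.0, -1.0] is just one of the example sink states mentioned above; the result is indexed as [step][trajectory][feature], matching batch_shape = (n_states, n_trajectories).

```python
from typing import List

State = List[float]


def pad_trajectories(trajectories: List[List[State]], sf: State) -> List[List[State]]:
    """Stack variable-length trajectories into a [step][trajectory] nested list,
    padding the shorter ones with copies of s_f."""
    n_states = max(len(traj) for traj in trajectories)
    batch: List[List[State]] = []
    for t in range(n_states):
        # Row t holds the t-th state of every trajectory, or s_f if it has ended.
        batch.append(
            [list(traj[t]) if t < len(traj) else list(sf) for traj in trajectories]
        )
    return batch


sf = [-1.0, -1.0]  # example sink state, state_shape = (2,)
short_traj = [[0.0, 0.0], [1.0, 0.0]]
long_traj = [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]
batch = pad_trajectories([short_traj, long_traj], sf)
# batch_shape = (n_states, n_trajectories) = (3, 2), state_shape = (2,)
print(len(batch), len(batch[0]), len(batch[0][0]))  # 3 2 2
print(batch[2])  # [[-1.0, -1.0], [1.0, 1.0]]
```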

Parameters
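  • states_tensor (StatesTensor) – Tensor of shape (*batch_shape, *state_shape) representing the batch of states.

  • forward_masks (ForwardMasksTensor | None) – Boolean mask of the allowed forward actions. If None, the masks are computed by make_masks.

  • backward_masks (BackwardMasksTensor | None) – Boolean mask of the allowed backward actions. If None, the masks are computed by make_masks.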

property device: torch.device
Return type

torch.device

property is_initial_state: DonesTensor

Return a boolean tensor of shape=(*batch_shape,), where True means that the state is \(s_0\) of the DAG.

Return type

DonesTensor

property is_sink_state: DonesTensor

Return a boolean tensor of shape=(*batch_shape,), where True means that the state is \(s_f\) of the DAG.

Return type

DonesTensor

property log_rewards: RewardsTensor
Return type

RewardsTensor

s0: ClassVar[OneStateTensor]
sf: ClassVar[OneStateTensor]
state_shape: ClassVar[tuple[int, ...]]

__getitem__(index)

Access particular states of the batch.

Parameters

index (int | Sequence[int] | Sequence[bool]) –

Return type

States

__len__()

Returns the number of elements in the container

__repr__()

Return repr(self).

compare(other)

Given a tensor of states, returns a tensor of booleans indicating whether the states are equal to the states in self.

Parameters

other (StatesTensor) – Tensor of states to compare to.

Returns

Tensor of booleans indicating whether the states are equal to the states in self.

Return type

DonesTensor
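The comparison can be sketched on nested lists: an elementwise equality that is then reduced over the state dimensions, so each state yields a single boolean. Using the same comparison against a repeated s0 or sf to obtain is_initial_state or is_sink_state is an assumption about how those properties could be computed; compare and matches below are hypothetical stand-ins, not the library code.

```python
from typing import List

State = List[int]


def compare(states: List[State], other: List[State]) -> List[bool]:
    """True where a state equals the corresponding state in `other` on every feature,
    i.e. an elementwise equality reduced over the state dimensions."""
    if len(states) != len(other):
        raise ValueError("both batches must have the same batch shape")
    return [all(x == y for x, y in zip(a, b)) for a, b in zip(states, other)]


def matches(states: List[State], reference: State) -> List[bool]:
    """Repeat a single reference state (e.g. s0 or sf) over the batch, then compare."""
    return compare(states, [reference] * len(states))


s0 = [0, 0]
sf = [-1, -1]
batch = [[0, 0], [1, 0], [-1, -1]]
print(matches(batch, s0))  # [True, False, False]  (like is_initial_state)
print(matches(batch, sf))  # [False, False, True]  (like is_sink_state)
```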

extend(other)

Appends another States object to this one, in place. Both objects must have the same number of batch dimensions, which must be 1 or 2.

Parameters

other (States) – Batch of states to append.

Raises

ValueError – if self and other do not have the same number of batch dimensions, or if that number is neither 1 nor 2.

Return type

None
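Reading "1 or 2" as the number of batch dimensions, the two cases of extend can be sketched on nested lists. extend_1d, extend_2d and pad_first_dim are hypothetical helpers; padding both operands with s_f before joining them along the second dimension is an assumption consistent with extend_with_sf, not a transcription of the library code.

```python
from typing import List

State = List[float]


def extend_1d(self_batch: List[State], other: List[State]) -> List[State]:
    """batch_shape (n,) + (m,) -> (n + m,): concatenate along the only batch dimension."""
    return self_batch + other


def pad_first_dim(
    batch: List[List[State]], required_first_dim: int, sf: State
) -> List[List[State]]:
    """Append rows of s_f until the first batch dimension reaches required_first_dim."""
    n_trajectories = len(batch[0])
    n_missing = max(0, required_first_dim - len(batch))
    return batch + [[list(sf) for _ in range(n_trajectories)] for _ in range(n_missing)]


def extend_2d(
    self_batch: List[List[State]], other: List[List[State]], sf: State
) -> List[List[State]]:
    """batch_shape (a, b) + (c, d) -> (max(a, c), b + d)."""
    n_states = max(len(self_batch), len(other))
    # Pad both batches with s_f to the same number of steps...
    left = pad_first_dim(self_batch, n_states, sf)
    right = pad_first_dim(other, n_states, sf)
    # ...then place the trajectories of `other` next to those of `self_batch`.
    return [row_left + row_right for row_left, row_right in zip(left, right)]


sf = [-1.0]
a = [[[0.0]], [[1.0]]]                                # batch_shape (2, 1)
b = [[[0.0], [0.0]], [[2.0], [3.0]], [[4.0], [5.0]]]  # batch_shape (3, 2)
merged = extend_2d(a, b, sf)                          # batch_shape (3, 3)
print(merged[2])  # [[-1.0], [4.0], [5.0]]
```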

extend_with_sf(required_first_dim)

Takes a two-dimensional batch of states (i.e. with batch_shape (a, b)) and extends it, in place, to batch_shape (required_first_dim, b) by appending the required number of s_f tensors along the first dimension. This is useful for padding trajectories of different lengths to a common length.

Parameters

required_first_dim (int) – Target size of the first batch dimension.

Return type

None

flatten()

Flatten the batch dimensions of the states into a single dimension. This is useful, for example, when extracting individual states from trajectories.

Return type

States
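A minimal sketch of flattening on nested lists, assuming row-major order (the first batch dimension outermost, as when reshaping a contiguous tensor). The helper flatten and the follow-up step that drops s_f padding are illustrative assumptions, not the library implementation.

```python
from typing import List

State = List[float]


def flatten(batch: List[List[State]]) -> List[State]:
    """Collapse batch_shape (a, b) into (a * b,), first dimension outermost
    (the same order as reshaping a contiguous row-major tensor)."""
    return [state for row in batch for state in row]


sf = [-1.0]
batch = [[[0.0], [0.0]], [[1.0], [-1.0]]]  # batch_shape (2, 2), one s_f padding entry
flat = flatten(batch)                      # batch_shape (4,)
# When extracting individual states from trajectories, the s_f padding is
# typically dropped afterwards.
real_states = [s for s in flat if s != sf]
print(real_states)  # [[0.0], [0.0], [1.0]]
```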

classmethod from_batch_shape(batch_shape, random=False)

Create a States object with the given batch shape, with all states initialized to s_0. If random is True, the states are instead initialized randomly; this requires the States subclass to implement the make_random_states_tensor class method.

Parameters
  • batch_shape (tuple[int]) – Batch shape of the States object to create.

  • random (bool) – Whether to initialize the states randomly instead of to s_0.

Return type

States
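The default (random=False) case amounts to building a structure of shape (*batch_shape, *state_shape) in which every state is a copy of s_0. The sketch below shows this with nested lists; full_batch and the example s0 are hypothetical, and the random case is not covered.

```python
from typing import Any, List, Sequence


def full_batch(batch_shape: Sequence[int], state: List[float]) -> Any:
    """Nested lists of shape (*batch_shape, *state_shape), each leaf a copy of `state`."""
    if not batch_shape:
        return list(state)  # a fresh copy, so entries do not alias each other
    return [full_batch(batch_shape[1:], state) for _ in range(batch_shape[0])]


s0 = [0.0, 0.0]                 # example initial state, state_shape = (2,)
batch = full_batch((3, 2), s0)  # batch_shape (3, 2)
print(batch[2][1])  # [0.0, 0.0]
```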

classmethod make_initial_states_tensor(batch_shape)
Parameters

batch_shape (tuple[int]) –

Return type

StatesTensor

make_masks()

Create the forward and backward masks for the states. This method is called only if the masks are not provided at initialization.

Return type

tuple[ForwardMasksTensor, BackwardMasksTensor]

abstract classmethod make_random_states_tensor(batch_shape)
Parameters

batch_shape (tuple[int]) –

Return type

StatesTensor

update_masks()

Update the masks, if necessary. This method should be called after each action is taken.

Return type

None

class gfn.containers.Trajectories(env, states=None, actions=None, when_is_done=None, is_backward=False, log_rewards=None, log_probs=None)

Bases: gfn.containers.base.Container

Container for a batch of trajectories (sequences of states and the actions taken between them).

Parameters
  • env (gfn.envs.Env) –

  • states (States | None) –

  • actions (Tensor2D | None) –

  • when_is_done (Tensor1D | None) –

  • is_backward (bool) –

  • log_rewards (FloatTensor1D | None) –

  • log_probs (FloatTensor2D | None) –

property last_states: gfn.containers.states.States
Return type

gfn.containers.states.States

property log_rewards: FloatTensor1D | None
Return type

FloatTensor1D | None

property max_length: int
Return type

int

property n_trajectories: int
Return type

int

__getitem__(index)

Returns a subset of the n_trajectories trajectories.

Parameters

index (int | Sequence[int]) –

Return type

Trajectories

__len__()

Returns the number of elements in the container

Return type

int

__repr__()

Return repr(self).

Return type

str

extend(other)

Extend the trajectories with another set of trajectories.

Parameters

other (Trajectories) –

Return type

None

extend_actions(required_first_dim)

Extends the actions and log_probs along the first dimension by adding -1s as necessary. This is useful for extending trajectories of different lengths.

Parameters

required_first_dim (int) –

Return type

None

static revert_backward_trajectories(trajectories)
Parameters

trajectories (Trajectories) –

Return type

Trajectories

to_non_initial_intermediary_and_terminating_states()

Returns a tuple of two States objects from the trajectories: all non-initial intermediary states and all terminating states.

Returns

  • All the intermediary states in the trajectories that are not s0.

  • All the terminating states in the trajectories that are not s0.

Return type

Tuple[States, States]

to_states()

Returns a States object containing all the states in the trajectories.

Return type

gfn.containers.states.States

to_transitions()

Returns a Transitions object from the trajectories

Return type

gfn.containers.transitions.Transitions
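A rough sketch of what this conversion involves, on nested lists: every non-padding step of every trajectory becomes one (state, action, next_state, is_done) record. The Transition record, the to_transitions helper, the layout in which states has one more step than actions, and the rule that is_done marks transitions into s_f are all assumptions for illustration; only the use of -1 to pad actions comes from the extend_actions description above.

```python
from typing import List, NamedTuple

State = List[float]


class Transition(NamedTuple):
    state: State
    action: int
    next_state: State
    is_done: bool


def to_transitions(
    states: List[List[State]], actions: List[List[int]], sf: State
) -> List[Transition]:
    """states: [step][trajectory], with one more step than actions.
    actions: [step][trajectory], padded with -1 where a trajectory has already ended."""
    transitions: List[Transition] = []
    for t, row in enumerate(actions):
        for j, action in enumerate(row):
            if action == -1:  # padding entry, not a real transition
                continue
            next_state = states[t + 1][j]
            # Here a transition is "done" when it leads to the sink state s_f.
            transitions.append(
                Transition(states[t][j], action, next_state, next_state == sf)
            )
    return transitions


sf = [-1.0]
# Hypothetical 1-D environment: action 0 increments the state, action 1 exits to s_f.
# Trajectory A: [0] -0-> [1] -1-> s_f;  trajectory B: [0] -1-> s_f (then padded).
states = [[[0.0], [0.0]], [[1.0], [-1.0]], [[-1.0], [-1.0]]]
actions = [[0, 1], [1, -1]]
transitions = to_transitions(states, actions, sf)
print([tr.is_done for tr in transitions])  # [False, True, True]
```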

class gfn.containers.Transitions(env, states=None, actions=None, is_done=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)

Bases: gfn.containers.base.Container

Container for a batch of transitions (state, action, next state).

Parameters
  • env (gfn.envs.Env) –

  • states (States | None) –

  • actions (LongTensor | None) –

  • is_done (BoolTensor | None) –

  • next_states (States | None) –

  • is_backward (bool) –

  • log_rewards (FloatTensor | None) –

  • log_probs (FloatTensor | None) –

property all_log_rewards: PairFloatTensor

This property is applicable to environments where all states are terminating. It evaluates the rewards for all transitions that do not end in the sink state, which is useful for the Modified Detailed Balance loss.

Return type

PairFloatTensor

property last_states: gfn.containers.states.States

Get the last states, i.e. the terminating states.

Return type

gfn.containers.states.States

property log_rewards: FloatTensor | None
Return type

FloatTensor | None

property n_transitions: int
Return type

int

__getitem__(index)

Access particular transitions of the batch.

Parameters

index (int | Sequence[int]) –

Return type

Transitions

__len__()

Returns the number of elements in the container

Return type

int

__repr__()

Return repr(self).

extend(other)

Extend the Transitions object with another Transitions object.

Parameters

other (Transitions) –

Return type

None