gfn.containers.states

Module Contents

Classes

States

Base class for states, seen as nodes of the DAG.

Functions

correct_cast(forward_masks, backward_masks)

Casts the given masks to the correct type, if they are not None.

Attributes

BackwardMasksTensor

DonesTensor

ForwardMasksTensor

OneStateTensor

RewardsTensor

StatesTensor

gfn.containers.states.BackwardMasksTensor
gfn.containers.states.DonesTensor
gfn.containers.states.ForwardMasksTensor
gfn.containers.states.OneStateTensor
gfn.containers.states.RewardsTensor
class gfn.containers.states.States(states_tensor, forward_masks=None, backward_masks=None)

Bases: gfn.containers.base.Container, abc.ABC

Base class for states, seen as nodes of the DAG. Each environment needs its own States subclass. A States object is a collection of multiple states (nodes of the DAG). Batching requires a tensor representation of the states: if a single state is represented by a tensor of shape (*state_shape), a batch of states is represented by a States object whose states_tensor attribute has shape (*batch_shape, *state_shape). Other representations are possible (e.g. a state as a string, a numpy array, or a graph), but they should not be batched.

If the environment’s action space is discrete, each States object also has boolean forward_masks and backward_masks attributes indicating which actions are allowed at each state.

A batch_shape attribute is also required to keep track of the batch dimensions. A single trajectory can be represented by a States object with batch_shape = (n_states,), and multiple trajectories by a States object with batch_shape = (n_states, n_trajectories).

Because trajectories can have different lengths, batching them requires padding every trajectory shorter than the longest one with dummy states. The dummy state is the environment’s s_f attribute (e.g. [-1, …, -1] or [-inf, …, -inf]); it is never processed and serves only to pad the batch of states.

Parameters
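  • states_tensor (StatesTensor) – Tensor of shape (*batch_shape, *state_shape) holding the batched states.

  • forward_masks (ForwardMasksTensor | None) – Optional boolean mask of the forward actions allowed at each state (see make_masks()).

  • backward_masks (BackwardMasksTensor | None) – Optional boolean mask of the backward actions allowed at each state (see make_masks()).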
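The shape contract for states_tensor can be illustrated with plain PyTorch: the trailing dimensions describe one state, and whatever precedes them is the batch shape. This is a minimal sketch, not the library's code; infer_batch_shape is a hypothetical helper, and its state_shape argument stands in for the state_shape class variable.

```python
import torch


def infer_batch_shape(states_tensor: torch.Tensor, state_shape: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper: split a (*batch_shape, *state_shape) shape into batch_shape."""
    n_state_dims = len(state_shape)
    split = states_tensor.dim() - n_state_dims
    # The trailing dimensions must match the per-state shape exactly.
    if tuple(states_tensor.shape[split:]) != tuple(state_shape):
        raise ValueError(
            f"shape {tuple(states_tensor.shape)} does not end with state_shape {state_shape}"
        )
    return tuple(states_tensor.shape[:split])


# 4 trajectories of 5 states each, every state being a vector of length 2.
state_shape = (2,)
states_tensor = torch.zeros(5, 4, *state_shape)
print(infer_batch_shape(states_tensor, state_shape))  # (5, 4)
```

Here the batch shape (5, 4) follows the (n_states, n_trajectories) convention described above.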

property device: torch.device
Return type

torch.device

property is_initial_state: DonesTensor

Return a boolean tensor of shape=(*batch_shape,), where True means that the state is s_0 of the DAG.

Return type

DonesTensor

property is_sink_state: DonesTensor

Return a boolean tensor of shape=(*batch_shape,), where True means that the state is s_f of the DAG.

Return type

DonesTensor

property log_rewards: RewardsTensor
Return type

RewardsTensor

s0: ClassVar[OneStateTensor]
sf: ClassVar[OneStateTensor]
state_shape: ClassVar[tuple[int, ...]]
__getitem__(index)

Access particular states of the batch.

Parameters

index (int | Sequence[int] | Sequence[bool]) –

Return type

States
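Indexing selects states along the batch dimensions and leaves each state intact. The sketch below shows the effect of the three documented index types on a raw tensor with a one-dimensional batch; it illustrates tensor indexing only and does not claim how __getitem__ rebuilds the resulting States object.

```python
import torch

state_shape = (3,)
# A one-dimensional batch of 4 states, each a vector of length 3.
states_tensor = torch.arange(12).reshape(4, *state_shape)

by_int = states_tensor[2]                      # int: a single state, shape (3,)
by_ints = states_tensor[[0, 3]]                # Sequence[int]: shape (2, 3)
mask = torch.tensor([True, False, True, False])
by_bools = states_tensor[mask]                 # Sequence[bool]: shape (2, 3)

print(by_int.shape, by_ints.shape, by_bools.shape)
```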

__len__()

Returns the number of elements in the container.

__repr__()

Return repr(self).

compare(other)

Given a tensor of states, returns a boolean tensor of shape (*batch_shape,) that is True wherever a state is equal to the corresponding state in self.

Parameters

other (StatesTensor) – Tensor of states to compare to.

Returns

Tensor of booleans indicating whether the states are equal to the states in self.

Return type

DonesTensor
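A whole state counts as equal only if every one of its entries matches, so an element-wise comparison has to be reduced over the state dimensions. The following is a minimal sketch of that reduction under the assumption that states are plain tensors; compare_states and its state_ndim argument are hypothetical. Comparing against s_0 or s_f in the same way yields masks with the meaning documented for is_initial_state and is_sink_state.

```python
import torch


def compare_states(states_tensor: torch.Tensor, other: torch.Tensor, state_ndim: int) -> torch.Tensor:
    """Hypothetical helper: True where a whole state equals the matching state in `other`.

    `states_tensor` has shape (*batch_shape, *state_shape) and `other` must broadcast
    to it; the result is a boolean tensor of shape batch_shape.
    """
    equal = states_tensor == other  # element-wise, shape (*batch_shape, *state_shape)
    if state_ndim == 0:
        return equal  # scalar states: nothing to reduce
    # Collapse the state dimensions into one, then require every entry to match.
    return equal.flatten(start_dim=equal.dim() - state_ndim).all(dim=-1)


s0 = torch.tensor([0, 0])
states_tensor = torch.tensor([[0, 0], [1, 0], [0, 0]])  # batch_shape (3,), state_shape (2,)
print(compare_states(states_tensor, s0, state_ndim=1).tolist())  # [True, False, True]
```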

extend(other)

Concatenates another States object with the same batch shape onto this one, in place. The batch shape must have one or two dimensions.

Parameters

other (States) – Batch of states to collate to.

Raises

ValueError – if self.batch_shape != other.batch_shape, or if len(self.batch_shape) is not 1 or 2

Return type

None
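A sketch of the documented contract on raw tensors follows. It assumes, based on the batch_shape conventions in the class description, that a one-dimensional batch (one trajectory) grows along dim 0 by gaining states, while a two-dimensional batch grows along dim 1 by gaining trajectories. extend_states is hypothetical and returns a new tensor, whereas the documented method modifies self in place.

```python
import torch


def extend_states(a: torch.Tensor, b: torch.Tensor, state_ndim: int) -> torch.Tensor:
    """Hypothetical helper mirroring the documented checks of States.extend.

    Assumption: 1-D batches are concatenated along dim 0 (more states),
    2-D batches along dim 1 (more trajectories).
    """
    a_batch = a.shape[: a.dim() - state_ndim]
    b_batch = b.shape[: b.dim() - state_ndim]
    if a_batch != b_batch or len(a_batch) not in (1, 2):
        raise ValueError(f"cannot extend batch shape {tuple(a_batch)} with {tuple(b_batch)}")
    dim = 0 if len(a_batch) == 1 else 1
    return torch.cat((a, b), dim=dim)


trajectory = torch.zeros(3, 2)  # batch_shape (3,), state_shape (2,)
print(extend_states(trajectory, torch.ones(3, 2), state_ndim=1).shape)  # torch.Size([6, 2])
batch = torch.zeros(4, 5, 2)    # 5 trajectories of 4 states
print(extend_states(batch, torch.ones(4, 5, 2), state_ndim=1).shape)  # torch.Size([4, 10, 2])
```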

extend_with_sf(required_first_dim)

Takes a two-dimensional batch of states (i.e. of batch_shape (a, b)) and extends it in place to batch_shape (required_first_dim, b) by appending the required number of s_f tensors along the first dimension. This is useful for padding trajectories of different lengths to a common length.

Parameters

required_first_dim (int) –

Return type

None
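The padding step can be sketched on a raw tensor as follows. This is a minimal illustration, assuming s_f is given as a tensor of shape state_shape; extend_with_sf here is a hypothetical free function that returns a new tensor rather than modifying a States object in place.

```python
import torch


def extend_with_sf(states_tensor: torch.Tensor, sf: torch.Tensor, required_first_dim: int) -> torch.Tensor:
    """Hypothetical helper: pad (a, b, *state_shape) with s_f to (required_first_dim, b, *state_shape)."""
    a, b = states_tensor.shape[0], states_tensor.shape[1]
    n_missing = required_first_dim - a
    if n_missing < 0:
        raise ValueError(f"required_first_dim={required_first_dim} is smaller than {a}")
    if n_missing == 0:
        return states_tensor
    # One copy of s_f for every missing (step, trajectory) slot.
    padding = sf.expand(n_missing, b, *sf.shape).to(
        device=states_tensor.device, dtype=states_tensor.dtype
    )
    return torch.cat((states_tensor, padding), dim=0)


sf = torch.tensor([-1.0, -1.0])  # e.g. s_f = [-1, -1] for state_shape (2,)
batch = torch.zeros(3, 4, 2)     # 4 trajectories of 3 states each
padded = extend_with_sf(batch, sf, required_first_dim=5)
print(padded.shape)  # torch.Size([5, 4, 2])
```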

flatten()

Flatten the batch dimensions of the states into a single one, so that a batch of shape (a, b) becomes (a * b,). This is useful, for example, when extracting individual states from trajectories.

Return type

States
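On the underlying tensor, flattening merges every batch dimension while keeping each state intact. A minimal sketch with plain PyTorch, not the library's implementation:

```python
import torch

state_shape = (2,)
states_tensor = torch.arange(24).reshape(3, 4, *state_shape)  # batch_shape (3, 4)

# Merge all batch dimensions into one, keeping the state dimensions untouched.
flat = states_tensor.reshape(-1, *state_shape)
print(flat.shape)  # torch.Size([12, 2])
```

The order is row-major, so the state at batch position (i, j) ends up at flat index i * 4 + j.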

classmethod from_batch_shape(batch_shape, random=False)

Create a States object with the given batch shape, with all states initialized to s_0. If random is True, the states are initialized randomly instead; this requires the make_random_states_tensor class method, which is abstract in this base class, to be implemented.

Parameters
  • batch_shape (tuple[int]) –

  • random (bool) –

Return type

States

classmethod make_initial_states_tensor(batch_shape)
Parameters

batch_shape (tuple[int]) –

Return type

StatesTensor
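The two ways of building the states tensor can be sketched with plain PyTorch: repeating s_0 along new leading batch dimensions, or sampling random states. Everything here is hypothetical for illustration: the 2-D grid, s0, HEIGHT, and the uniform sampling rule (random states are environment-specific, which is why make_random_states_tensor is abstract).

```python
import torch

# Hypothetical class-level constants for a 2-D grid environment.
s0 = torch.tensor([0, 0])  # initial state, state_shape == (2,)
HEIGHT = 4                 # hypothetical grid size, used only for random states


def make_initial_states_tensor(batch_shape: tuple[int, ...]) -> torch.Tensor:
    # Repeat s_0 along new leading batch dims; the state dims are repeated once (unchanged).
    return s0.repeat(*batch_shape, *((1,) * s0.dim()))


def make_random_states_tensor(batch_shape: tuple[int, ...]) -> torch.Tensor:
    # Hypothetical sampling rule: uniform integer coordinates in [0, HEIGHT).
    return torch.randint(0, HEIGHT, (*batch_shape, *s0.shape))


def states_tensor_from_batch_shape(batch_shape: tuple[int, ...], random: bool = False) -> torch.Tensor:
    # Mirrors the documented choice made by from_batch_shape.
    if random:
        return make_random_states_tensor(batch_shape)
    return make_initial_states_tensor(batch_shape)


print(states_tensor_from_batch_shape((3, 4)).shape)  # torch.Size([3, 4, 2])
```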

make_masks()

Create the forward and backward masks for the states. This method is called only if the masks are not provided at initialization.

Return type

tuple[ForwardMasksTensor, BackwardMasksTensor]
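Mask construction depends entirely on the environment. As an illustration only, the sketch below assumes a hypothetical HyperGrid-like environment: states are integer coordinates in [0, height), forward actions 0..ndim-1 increment one coordinate, forward action ndim means "exit", and backward actions 0..ndim-1 decrement one coordinate. The function name, the action layout, and the height parameter are all assumptions. A subclass's update_masks might recompute such masks from the new states tensor after each action.

```python
import torch


def make_hypergrid_masks(states_tensor: torch.Tensor, height: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical masks for a HyperGrid-like environment.

    states_tensor: integer coordinates of shape (*batch_shape, ndim).
    Returns forward masks of shape (*batch_shape, ndim + 1) and
    backward masks of shape (*batch_shape, ndim).
    """
    batch_shape = states_tensor.shape[:-1]
    # Incrementing coordinate i is allowed only while it stays below height - 1.
    can_increment = states_tensor < height - 1
    # The hypothetical exit action is always allowed.
    exit_allowed = torch.ones(*batch_shape, 1, dtype=torch.bool, device=states_tensor.device)
    forward_masks = torch.cat((can_increment, exit_allowed), dim=-1)
    # Decrementing coordinate i is allowed only while it is above 0.
    backward_masks = states_tensor > 0
    return forward_masks, backward_masks


states = torch.tensor([[0, 0], [3, 1]])  # batch_shape (2,), ndim 2
fwd, bwd = make_hypergrid_masks(states, height=4)
print(fwd.tolist())  # [[True, True, True], [False, True, True]]
print(bwd.tolist())  # [[False, False], [True, True]]
```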

abstract classmethod make_random_states_tensor(batch_shape)
Parameters

batch_shape (tuple[int]) –

Return type

StatesTensor

update_masks()

Update the masks, if necessary. This method should be called after each action is taken.

Return type

None

gfn.containers.states.StatesTensor
gfn.containers.states.correct_cast(forward_masks, backward_masks)

Casts the given masks to the correct type, if they are not None. This function is to help with type checking only.

Parameters
  • forward_masks (ForwardMasksTensor | None) –

  • backward_masks (BackwardMasksTensor | None) –

Return type

tuple[ForwardMasksTensor, BackwardMasksTensor]
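Because correct_cast exists only for type checking, a helper of this kind can be written with typing.cast, which returns its argument unchanged at runtime. The sketch below is a hypothetical stand-in: the two aliases are simplified to torch.Tensor and are not the module's own definitions.

```python
from typing import Optional, cast

import torch

# Hypothetical stand-ins for the ForwardMasksTensor / BackwardMasksTensor aliases.
ForwardMasksTensor = torch.Tensor
BackwardMasksTensor = torch.Tensor


def correct_cast(
    forward_masks: Optional[ForwardMasksTensor],
    backward_masks: Optional[BackwardMasksTensor],
) -> tuple[ForwardMasksTensor, BackwardMasksTensor]:
    # typing.cast does nothing at runtime; it only tells the type checker to
    # treat each value as the narrower, non-Optional type.
    return cast(ForwardMasksTensor, forward_masks), cast(BackwardMasksTensor, backward_masks)


masks = torch.ones(2, 3, dtype=torch.bool)
fwd, bwd = correct_cast(masks, None)
print(fwd is masks, bwd is None)  # True True
```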