gfn.states

Module Contents

Classes

DiscreteStates

Base class for states of discrete environments.

States

Base class for states, seen as nodes of the DAG.

class gfn.states.DiscreteStates(tensor, forward_masks=None, backward_masks=None)

Bases: States, abc.ABC

Base class for states of discrete environments.

States are endowed with forward_masks and backward_masks: boolean attributes representing which actions are allowed at each state. This is the mechanism by which all elements of the library (including an environment’s validate_actions method) verify the allowed actions at each state.

Parameters
  • tensor (torchtyping.TensorType[batch_shape, state_shape, torch.float]) –

  • forward_masks (Optional[torchtyping.TensorType[batch_shape, n_actions, torch.bool]]) –

  • backward_masks (Optional[torchtyping.TensorType[batch_shape, n_actions - 1, torch.bool]]) –

forward_masks

A boolean tensor of allowable forward policy actions.

backward_masks

A boolean tensor of allowable backward policy actions.
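
As a rough illustration of how boolean masks of this kind are typically consumed, the sketch below uses plain PyTorch tensors rather than the library's classes. The names logits, masked_logits and probs are assumptions made for the example; they are not part of gfn.states.

```python
import torch

# Batch of 2 states, 3 possible forward actions (values are illustrative).
forward_masks = torch.tensor([[True, False, True],
                              [True, True, False]])
logits = torch.zeros(2, 3)  # hypothetical policy outputs

# Disallowed actions get -inf so they receive zero probability.
masked_logits = logits.masked_fill(~forward_masks, float("-inf"))
probs = torch.softmax(masked_logits, dim=-1)
```

Masking the logits before the softmax keeps each row a valid distribution over the allowed actions only.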

device :ClassVar[torch.device]
n_actions :ClassVar[int]
__getitem__(index)

Access particular states of the batch.

Parameters

index (int | Sequence[int] | Sequence[bool]) –

Return type

DiscreteStates
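
The three accepted index types behave like ordinary tensor indexing along the batch dimension. The sketch below shows this on a plain tensor standing in for the tensor attribute; it does not construct a DiscreteStates object, and the variable names are illustrative.

```python
import torch

# 4 states, each of state_shape (2,); a plain tensor stands in for States.tensor.
tensor = torch.arange(8, dtype=torch.float).reshape(4, 2)

single = tensor[1]                                        # int index
subset = tensor[[0, 2]]                                   # sequence of ints
kept = tensor[torch.tensor([True, False, False, True])]   # boolean selection
```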

__setitem__(index, states)

Set particular states of the batch.

Parameters
  • index (int | Sequence[int] | Sequence[bool]) –

  • states (DiscreteStates) –

Return type

None

_check_both_forward_backward_masks_exist()
extend(other)

Concatenates to another States object along the final batch dimension.

Both States objects must have the same number of batch dimensions, which should be 1 or 2.

Parameters

other (States) – Batch of states to concatenate to.

Raises
  • ValueError – if the two objects do not have the same number of batch dimensions, or if that number is not 1 or 2.

Return type

None

extend_with_sf(required_first_dim)

Extends forward and backward masks along the first batch dimension.

After the state tensor has been padded with \(s_f\) along the first batch dimension up to required_first_dim, the forward and backward masks are padded with ones along that same dimension, also up to required_first_dim.

Parameters

required_first_dim (int) – The size of the first batch dimension post-expansion.

Return type

None
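
A minimal sketch of the mask padding described above, written against a plain boolean tensor. The sizes are illustrative, and the real method works in place on the object rather than on a free-standing tensor.

```python
import torch

# Masks for a batch of shape (a=2, b=3) with 4 forward actions (illustrative sizes).
forward_masks = torch.zeros(2, 3, 4, dtype=torch.bool)
required_first_dim = 5

# Number of rows to add along the first batch dimension.
missing = required_first_dim - forward_masks.shape[0]
padding = torch.ones(missing, *forward_masks.shape[1:], dtype=torch.bool)
forward_masks = torch.cat([forward_masks, padding], dim=0)
```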

flatten()

Flatten the batch dimension of the states.

Useful for example when extracting individual states from trajectories.

Return type

DiscreteStates

abstract update_masks()

Updates the masks, called after each action is taken.

Return type

None
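
Since update_masks is abstract, each environment supplies its own rule. The sketch below shows one possible rule for a hypothetical grid environment in which action i increments coordinate i and the last action exits. The environment, the height variable and the rule itself are assumptions for illustration, not something defined in this module.

```python
import torch

height = 3  # hypothetical grid side length
# Two states of a 2-D grid; action i increments coordinate i, the last action exits.
tensor = torch.tensor([[0, 2], [1, 1]])

n_actions = tensor.shape[-1] + 1
forward_masks = torch.ones(tensor.shape[0], n_actions, dtype=torch.bool)
# Incrementing a coordinate that is already at the boundary is not allowed.
forward_masks[:, :-1] = tensor < height - 1
# A backward action decrements a coordinate, so that coordinate must be positive.
backward_masks = tensor > 0
```

Note that backward_masks has n_actions - 1 columns here, which matches the shapes given in the constructor parameters.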

class gfn.states.States(tensor)

Bases: abc.ABC

Base class for states, seen as nodes of the DAG.

For each environment, a States subclass is needed. A States object is a collection of multiple states (nodes of the DAG). A tensor representation of the states is required for batching. If a state is represented with a tensor of shape (*state_shape), a batch of states is represented with a States object, with the attribute tensor of shape (*batch_shape, *state_shape). Other representations are possible (e.g. a string, a NumPy array, or a graph), but these representations cannot be batched.

If the environment’s action space is discrete (i.e. the environment subclasses DiscreteEnv), then each States object is also endowed with forward_masks and backward_masks, boolean attributes representing which actions are allowed at each state. This makes it possible to instantly access the allowed actions at each state, without having to call the environment’s validate_actions method. Put differently, validate_actions for such environments reads the masks directly. This is handled in the DiscreteStates subclass.

A batch_shape attribute is also required, to keep track of the batch dimension. A trajectory can be represented by a States object with batch_shape = (n_states,). Multiple trajectories can be represented by a States object with batch_shape = (n_states, n_trajectories).
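
To make the shape convention concrete, the sketch below builds plain tensors with one and two batch dimensions and recovers the batch shape from them. The helper batch_shape_of is hypothetical and only illustrates the relationship between tensor.shape, batch_shape and state_shape.

```python
import torch

state_shape = (2,)  # shape of a single state (illustrative)
n_states, n_trajectories = 5, 3

# One trajectory: batch_shape == (n_states,)
one = torch.zeros(n_states, *state_shape)
# Several trajectories: batch_shape == (n_states, n_trajectories)
many = torch.zeros(n_states, n_trajectories, *state_shape)

def batch_shape_of(tensor, state_shape):
    """Leading dimensions left after removing the trailing state dimensions."""
    return tuple(tensor.shape[: tensor.dim() - len(state_shape)])
```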

Because trajectories can have different lengths, batching requires padding trajectories that are shorter than the longest one with a dummy state. The dummy state is the \(s_f\) attribute of the environment (e.g. [-1, …, -1] or [-inf, …, -inf]). It is never processed and serves only to pad the batch of states.
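
The padding idea can be sketched with plain tensors as follows. The sink value [-1, -1] and the two toy trajectories are assumptions chosen for the example.

```python
import torch

sf = torch.tensor([-1.0, -1.0])  # illustrative sink state
# Two trajectories of lengths 3 and 1, each state of shape (2,).
trajectories = [torch.zeros(3, 2), torch.ones(1, 2)]

max_len = max(t.shape[0] for t in trajectories)
# Start from a batch filled with sf, shape (n_states, n_trajectories, *state_shape).
batch = sf.repeat(max_len, len(trajectories), 1)
# Copy each trajectory into its column; the remainder stays as padding.
for j, t in enumerate(trajectories):
    batch[: t.shape[0], j] = t
```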

Parameters

tensor (torchtyping.TensorType[batch_shape, state_shape]) –

tensor

Tensor representing a batch of states.

batch_shape

Sizes of the batch dimensions.

_log_rewards

Stores the log rewards of each state.

property device: torch.device
Return type

torch.device

property is_initial_state: torchtyping.TensorType[batch_shape, torch.bool]

Return a tensor that is True for states that are \(s_0\) of the DAG.

Return type

torchtyping.TensorType[batch_shape, torch.bool]

property is_sink_state: torchtyping.TensorType[batch_shape, torch.bool]

Return a tensor that is True for states that are \(s_f\) of the DAG.

Return type

torchtyping.TensorType[batch_shape, torch.bool]
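
One plausible way to compute such indicator tensors is to compare every state against \(s_0\) or \(s_f\) and require all components to match. The sketch below assumes a one-dimensional state_shape and illustrative values for s0 and sf.

```python
import torch

s0 = torch.tensor([0.0, 0.0])    # illustrative initial state
sf = torch.tensor([-1.0, -1.0])  # illustrative sink state
tensor = torch.tensor([[0.0, 0.0], [1.0, 0.0], [-1.0, -1.0]])

# A state matches only if every component matches.
is_initial = (tensor == s0).all(dim=-1)
is_sink = (tensor == sf).all(dim=-1)
```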

property log_rewards: torchtyping.TensorType[batch_shape, torch.float]
Return type

torchtyping.TensorType[batch_shape, torch.float]

s0 :ClassVar[torchtyping.TensorType[States.state_shape, torch.float]]
sf :ClassVar[torchtyping.TensorType[States.state_shape, torch.float]]
state_shape :ClassVar[tuple[int, Ellipsis]]
__getitem__(index)

Access particular states of the batch.

Parameters

index (int | Sequence[int] | Sequence[bool]) –

Return type

States

__len__()
__repr__()

Return repr(self).

__setitem__(index, states)

Set particular states of the batch.

Parameters
  • index (int | Sequence[int] | Sequence[bool]) –

  • states (States) –

Return type

None

compare(other)

Computes elementwise equality between the state tensor and an external tensor.

Parameters

other (torchtyping.TensorType[batch_shape, state_shape, torch.float]) – Tensor of states to compare to.

Return type

torchtyping.TensorType[batch_shape, torch.bool]

Returns: Tensor of booleans indicating whether the states are equal to the states in self.
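
A sketch of a comparison that returns one boolean per state, including the case where state_shape has more than one dimension. The free function and its n_state_dims argument are hypothetical; the actual method takes only other.

```python
import torch

def compare(tensor, other, n_state_dims):
    """Reduce elementwise equality over the trailing state dimensions."""
    equal = tensor == other
    for _ in range(n_state_dims):
        equal = equal.all(dim=-1)
    return equal

a = torch.zeros(4, 2, 3)  # batch_shape (4,), state_shape (2, 3)
b = a.clone()
b[1, 0, 0] = 1.0          # make the second state differ in one entry
result = compare(a, b, n_state_dims=2)
```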

extend(other)

Concatenates to another States object along the final batch dimension.

Both States objects must have the same number of batch dimensions, which should be 1 or 2.

Parameters

other (States) – Batch of states to concatenate to.

Raises
  • ValueError – if the two objects do not have the same number of batch dimensions, or if that number is not 1 or 2.

Return type

None
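
The concatenation can be pictured on plain tensors as below. The helper extend_tensor is hypothetical, returns a new tensor instead of modifying an object in place, and assumes that with two batch dimensions the first dimensions already agree (padding with \(s_f\) is one way to get there).

```python
import torch

def extend_tensor(tensor, other, n_batch_dims):
    """Concatenate two state tensors along their final batch dimension."""
    if n_batch_dims not in (1, 2):
        raise ValueError("expected 1 or 2 batch dimensions")
    return torch.cat([tensor, other], dim=n_batch_dims - 1)

# One batch dimension: (3, 2) and (4, 2) give (7, 2).
flat = extend_tensor(torch.zeros(3, 2), torch.ones(4, 2), n_batch_dims=1)
# Two batch dimensions: (5, 3, 2) and (5, 4, 2) give (5, 7, 2).
trajs = extend_tensor(torch.zeros(5, 3, 2), torch.ones(5, 4, 2), n_batch_dims=2)
```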

extend_with_sf(required_first_dim)

Extends a 2-dimensional batch of states along the first batch dimension.

Given a batch of states (i.e. of batch_shape=(a, b)), extends it to a States object of batch_shape = (required_first_dim, b), by adding the required number of \(s_f\) tensors. This is useful to extend trajectories of different lengths.

Parameters

required_first_dim (int) – The size of the first batch dimension post-expansion.

Return type

None
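
A minimal sketch of this extension on a plain tensor. The free function below is hypothetical and returns a new tensor, whereas the documented method returns None and therefore presumably updates the object in place.

```python
import torch

def extend_with_sf(tensor, sf, required_first_dim):
    """Pad a (a, b, *state_shape) tensor with sf up to required_first_dim."""
    a, b = tensor.shape[:2]
    if required_first_dim <= a:
        return tensor  # nothing to add
    # Tile sf into a block of shape (required_first_dim - a, b, *state_shape).
    pad = sf.repeat(required_first_dim - a, b, *([1] * sf.dim()))
    return torch.cat([tensor, pad], dim=0)

sf = torch.tensor([-1.0, -1.0])  # illustrative sink state
out = extend_with_sf(torch.zeros(2, 3, 2), sf, required_first_dim=5)
```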

flatten()

Flatten the batch dimension of the states.

Useful for example when extracting individual states from trajectories.

Return type

States
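
On the underlying tensor, flattening amounts to collapsing all batch dimensions into one while leaving state_shape untouched. The shapes below are illustrative.

```python
import torch

state_shape = (2,)
tensor = torch.arange(24, dtype=torch.float).reshape(4, 3, *state_shape)

# Collapse the batch dimensions (4, 3) into a single dimension of size 12.
flat = tensor.reshape(-1, *state_shape)
```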

classmethod from_batch_shape(batch_shape, random=False, sink=False)

Create a States object with the given batch shape.

By default, all states are initialized to \(s_0\), the initial state. Optionally, one can initialize random states, which requires that the environment implements the make_random_states_tensor class method. sink can be used to initialize states at \(s_f\), the sink state. random and sink cannot both be True.

Parameters
  • batch_shape (tuple[int]) – Shape of the batch dimensions.

  • random (optional) – Initialize states randomly.

  • sink (optional) – States initialized with s_f (the sink state).

Raises

ValueError – If both random and sink are True.

Return type

States
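
The flag handling described above can be sketched as follows. The helper make_states_tensor is hypothetical: it returns a raw tensor rather than a States object, takes s0 and sf as explicit arguments, and uses torch.rand as a stand-in for whatever make_random_states_tensor an environment would define.

```python
import torch

def make_states_tensor(batch_shape, s0, sf, random=False, sink=False):
    """Hypothetical helper mirroring the random/sink flags described above."""
    if random and sink:
        raise ValueError("Only one of `random` and `sink` should be True.")
    if random:
        # Placeholder sampler; a real environment would define its own.
        return torch.rand(*batch_shape, *s0.shape)
    base = sf if sink else s0
    # Tile the chosen state over the batch dimensions.
    return base.repeat(*batch_shape, *([1] * base.dim()))

s0 = torch.tensor([0.0, 0.0])    # illustrative initial state
sf = torch.tensor([-1.0, -1.0])  # illustrative sink state
initial = make_states_tensor((3, 2), s0, sf)
sinks = make_states_tensor((3,), s0, sf, sink=True)
```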

classmethod make_initial_states_tensor(batch_shape)

Makes a tensor with a batch_shape of states consisting of \(s_0\)s.

Parameters

batch_shape (tuple[int]) –

Return type

torchtyping.TensorType[States.make_initial_states_tensor.batch_shape, state_shape, torch.float]

abstract classmethod make_random_states_tensor(batch_shape)

Makes a tensor with a batch_shape of random states. This is a placeholder that subclasses must implement.

Parameters

batch_shape (tuple[int]) –

Return type

torchtyping.TensorType[States.make_random_states_tensor.batch_shape, state_shape, torch.float]

classmethod make_sink_states_tensor(batch_shape)

Makes a tensor with a batch_shape of states consisting of \(s_f\)s.

Parameters

batch_shape (tuple[int]) –

Return type

torchtyping.TensorType[States.make_sink_states_tensor.batch_shape, state_shape, torch.float]