gfn.containers.states
Module Contents
Classes
- States – Base class for states, seen as nodes of the DAG.
Functions
- correct_cast(forward_masks, backward_masks) – Casts the given masks to the correct type, if they are not None.
Attributes
- gfn.containers.states.BackwardMasksTensor
- gfn.containers.states.DonesTensor
- gfn.containers.states.ForwardMasksTensor
- gfn.containers.states.OneStateTensor
- gfn.containers.states.RewardsTensor
- class gfn.containers.states.States(states_tensor, forward_masks=None, backward_masks=None)
Bases: gfn.containers.base.Container, abc.ABC
Base class for states, seen as nodes of the DAG. For each environment, a States subclass is needed. A States object is a collection of multiple states (nodes of the DAG). A tensor representation of the states is required for batching: if a single state is represented by a tensor of shape (*state_shape), a batch of states is represented by a States object whose states_tensor attribute has shape (*batch_shape, *state_shape). Other representations are possible (e.g. a state as a string, a numpy array, or a graph), but these representations should not be batched.
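The shape convention above can be checked with a small dependency-free sketch. Nested Python lists stand in for torch tensors, and `nested_shape` is a hypothetical helper playing the role of `Tensor.shape`. The 2D grid state is also an illustrative assumption. Stacking states along new leading dimensions gives a batch of shape (*batch_shape, *state_shape):

```python
from typing import Any


def nested_shape(x: Any) -> tuple[int, ...]:
    """Return the shape of a rectangular nested list (a stand-in for Tensor.shape)."""
    dims: list[int] = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


# A single state of a hypothetical 2D grid environment: state_shape = (2,).
state = [0, 3]
state_shape = nested_shape(state)

# A batch of 3 x 4 copies of that state: batch_shape = (3, 4).
batch_shape = (3, 4)
states_tensor = [[list(state) for _ in range(batch_shape[1])] for _ in range(batch_shape[0])]

# The batched representation has shape (*batch_shape, *state_shape).
print(nested_shape(states_tensor))  # (3, 4, 2)
```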
If the environment’s action space is discrete, each States object also has boolean forward_masks and backward_masks attributes indicating which forward and backward actions, respectively, are allowed at each state.
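To make the masks concrete, consider a hypothetical one-dimensional environment (not part of torchgfn). Its states are positions 0 to H - 1, its forward actions are "step right" and "exit", and its single backward action is "step left". The sketch below uses plain lists instead of boolean tensors and builds one row of allowed actions per state, so the masks have shape (*batch_shape, n_actions):

```python
H = 4  # hypothetical environment size: positions 0..H-1

FORWARD_ACTIONS = ("step_right", "exit")
BACKWARD_ACTIONS = ("step_left",)


def forward_mask(position: int) -> list[bool]:
    # Stepping right is allowed unless we are at the right edge; exiting is always allowed.
    return [position < H - 1, True]


def backward_mask(position: int) -> list[bool]:
    # Stepping left (undoing a step right) is allowed everywhere except at s_0 = 0.
    return [position > 0]


positions = [0, 2, 3]  # a batch of states with batch_shape = (3,)
forward_masks = [forward_mask(p) for p in positions]
backward_masks = [backward_mask(p) for p in positions]

print(forward_masks)   # [[True, True], [True, True], [False, True]]
print(backward_masks)  # [[False], [True], [True]]
```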
A batch_shape attribute is also required to keep track of the batch dimensions. A single trajectory can be represented by a States object with batch_shape = (n_states,), and multiple trajectories by a States object with batch_shape = (n_states, n_trajectories).
Because trajectories can have different lengths, batching requires padding every trajectory that is shorter than the longest one with a dummy state. This dummy state is the s_f attribute of the environment (e.g. [-1, …, -1] or [-inf, …, -inf]). It is never processed and is used only to pad the batch of states.
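Here is a minimal sketch of this padding. It assumes that states are one-element lists and that the sink state is the hypothetical value `SF = [-1]`. The hypothetical helper `pad_trajectories` stacks trajectories of different lengths into a batch of shape (n_states, n_trajectories, *state_shape), where n_states is the length of the longest trajectory. It is an illustration, not torchgfn code:

```python
SF = [-1]  # hypothetical dummy sink state s_f, used only for padding


def pad_trajectories(trajectories: list[list[list[int]]]) -> list[list[list[int]]]:
    """Stack trajectories (lists of states) into a (n_states, n_trajectories, *state_shape) batch."""
    n_states = max(len(traj) for traj in trajectories)
    batch = []
    for t in range(n_states):
        # Row t holds the t-th state of every trajectory, or s_f if that trajectory is shorter.
        batch.append([list(traj[t]) if t < len(traj) else list(SF) for traj in trajectories])
    return batch


trajectories = [
    [[0], [1], [2]],  # length 3
    [[0], [1]],       # length 2, padded with one s_f
]
batch = pad_trajectories(trajectories)
batch_shape = (len(batch), len(batch[0]))
print(batch_shape)  # (3, 2)
print(batch[2])     # [[2], [-1]]
```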
- Parameters
states_tensor (StatesTensor) –
forward_masks (ForwardMasksTensor | None) –
backward_masks (BackwardMasksTensor | None) –
- property device: torch.device
- Return type
torch.device
- property is_initial_state: DonesTensor
Return a boolean tensor of shape=(*batch_shape,), where True means that the state is \(s_0\) of the DAG.
- Return type
DonesTensor
- property is_sink_state: DonesTensor
Return a boolean tensor of shape=(*batch_shape,), where True means that the state is \(s_f\) of the DAG.
- Return type
DonesTensor
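Both properties compare every state in the batch with a single reference state and then reduce over the state dimensions, which leaves one boolean per batch element. The following list-based sketch assumes the hypothetical values `S0 = [0]` and `SF = [-1]` and a padded batch with batch_shape = (3, 2). `matches` is an illustrative helper, not part of the documented API:

```python
S0 = [0]   # hypothetical initial state s_0
SF = [-1]  # hypothetical sink state s_f


def matches(states: list[list[list[int]]], reference: list[int]) -> list[list[bool]]:
    """Return a (*batch_shape,) grid of booleans: True where a state equals `reference`."""
    # List equality compares every entry, i.e. it reduces over the state dimension.
    return [[state == reference for state in row] for row in states]


# batch_shape = (3, 2): two trajectories, the second one padded with s_f.
states = [
    [[0], [0]],
    [[1], [1]],
    [[2], [-1]],
]
is_initial_state = matches(states, S0)
is_sink_state = matches(states, SF)
print(is_initial_state)  # [[True, True], [False, False], [False, False]]
print(is_sink_state)     # [[False, False], [False, False], [False, True]]
```

With torch tensors, the same test is commonly written as an elementwise == followed by .all() over the trailing state dimensions.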
- property log_rewards: RewardsTensor
- Return type
RewardsTensor
- s0 :ClassVar[OneStateTensor]
- sf :ClassVar[OneStateTensor]
- state_shape :ClassVar[tuple[int, Ellipsis]]
- __getitem__(index)
Access particular states of the batch.
- Parameters
index (int | Sequence[int] | Sequence[bool]) –
- Return type
States
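The accepted index types behave like ordinary tensor indexing on the first batch dimension. The hypothetical helper `select` below reproduces PyTorch's semantics for the three cases on a list-based batch. This reference does not say whether States.__getitem__ keeps or drops the batch dimension for a single int, so treat that part of the sketch as an assumption:

```python
from typing import Any, Sequence, Union


def select(batch: list, index: Union[int, Sequence[int], Sequence[bool]]) -> Any:
    """Mimic tensor-style indexing on the first batch dimension of a list-based batch."""
    if isinstance(index, int):
        # Like tensor[i]: a single int drops the indexed dimension.
        return batch[index]
    index = list(index)
    if index and all(isinstance(i, bool) for i in index):
        # A boolean sequence acts as a mask with one entry per batch element.
        if len(index) != len(batch):
            raise IndexError("boolean mask must have one entry per batch element")
        return [item for item, keep in zip(batch, index) if keep]
    # Otherwise, a sequence of ints selects (and may reorder or repeat) elements.
    return [batch[i] for i in index]


batch = [[0], [1], [2], [3]]  # batch_shape = (4,), state_shape = (1,)
print(select(batch, 2))                           # [2]
print(select(batch, [0, 3]))                      # [[0], [3]]
print(select(batch, [True, False, True, False]))  # [[0], [2]]
```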
- __len__()
Returns the number of elements in the container.
- __repr__()
Return repr(self).
- compare(other)
Given a tensor of states, returns a tensor of booleans indicating whether the states are equal to the states in self.
- Parameters
other (StatesTensor) – Tensor of states to compare to.
- Returns
Tensor of booleans indicating whether the states are equal to the states in self.
- Return type
DonesTensor
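You can think of compare as an elementwise equality test followed by a reduction over the state dimensions, which gives one boolean per batch position. The sketch assumes that other has the same batch shape as self. It uses a hypothetical list-based function `compare_states` in place of the method:

```python
def compare_states(own: list[list[int]], other: list[list[int]]) -> list[bool]:
    """Return True at each batch position where the two states are equal in every entry."""
    if len(own) != len(other):
        raise ValueError("this sketch expects both batches to have the same batch shape")
    return [a == b for a, b in zip(own, other)]


# batch_shape = (3,), state_shape = (2,)
own = [[0, 0], [1, 2], [3, 3]]
other = [[0, 0], [2, 1], [3, 3]]
print(compare_states(own, other))  # [True, False, True]
```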
- extend(other)
Collates (concatenates) another States object onto this one, in place. Both objects must have the same number of batch dimensions, which should be 1 or 2.
- Parameters
other (States) – Batch of states to collate to.
- Raises
ValueError – if the batch shapes of self and other do not have the same number of dimensions, or if that number is neither 1 nor 2.
- Return type
None
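The sketch below shows the two supported cases on list-based batches under stated assumptions. One-dimensional batches are concatenated directly. Two-dimensional batches of trajectories, which are assumed here to be allowed to differ in length, are first padded with a hypothetical sink state `SF` to a common number of steps (as the description of extend_with_sf suggests) and then concatenated along the trajectory dimension. `extend_batches` is a hypothetical name. For simplicity, it returns a new batch rather than modifying self in place:

```python
SF = [-1]  # hypothetical sink state used for padding


def extend_batches(own: list, other: list, ndim: int) -> list:
    """Concatenate two list-based batches whose batch shapes both have `ndim` dimensions."""
    if ndim == 1:
        # batch_shape (a,) + (b,) -> (a + b,)
        return own + other
    if ndim == 2:
        # batch_shape (a1, b1) + (a2, b2) -> (max(a1, a2), b1 + b2), padding with s_f.
        n_steps = max(len(own), len(other))
        own_width, other_width = len(own[0]), len(other[0])
        padded_own = own + [[list(SF) for _ in range(own_width)] for _ in range(n_steps - len(own))]
        padded_other = other + [[list(SF) for _ in range(other_width)] for _ in range(n_steps - len(other))]
        return [row_a + row_b for row_a, row_b in zip(padded_own, padded_other)]
    raise ValueError(f"batch shapes with {ndim} dimensions are not supported")


a = [[[0]], [[1]], [[2]]]     # batch_shape (3, 1): one trajectory of 3 states
b = [[[0], [0]], [[1], [1]]]  # batch_shape (2, 2): two trajectories of 2 states
merged = extend_batches(a, b, ndim=2)
print(len(merged), len(merged[0]))  # 3 3
print(merged[2])                    # [[2], [-1], [-1]]
```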
- extend_with_sf(required_first_dim)
Takes a two-dimensional batch of states (i.e. of batch_shape (a, b)) and extends it, in place, to batch_shape (required_first_dim, b) by appending the required number of s_f states. This is useful for padding trajectories of different lengths to a common length.
- Parameters
required_first_dim (int) –
- Return type
None
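Here is a list-based sketch of this padding, assuming a hypothetical sink state `SF = [-1]`. The hypothetical function `pad_with_sf` mutates its argument and returns None, mirroring the documented return type. It appends fresh copies of s_f so that padded entries do not alias one another:

```python
SF = [-1]  # hypothetical sink state s_f


def pad_with_sf(batch: list, required_first_dim: int) -> None:
    """Pad a (a, b, *state_shape) list-based batch in place to (required_first_dim, b, ...)."""
    n_trajectories = len(batch[0])
    for _ in range(required_first_dim - len(batch)):
        # Each new row holds one fresh copy of s_f per trajectory.
        batch.append([list(SF) for _ in range(n_trajectories)])


batch = [[[0], [0]], [[1], [1]]]  # batch_shape (2, 2)
pad_with_sf(batch, required_first_dim=4)
print(len(batch), len(batch[0]))  # 4 2
print(batch[3])                   # [[-1], [-1]]
```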
- flatten()
Flattens the batch dimensions of the states into a single dimension. This is useful, for example, when extracting individual states from trajectories.
- Return type
States
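For a two-dimensional batch, flattening merges the two leading batch dimensions. A batch of shape (a, b, *state_shape) becomes one of shape (a * b, *state_shape). Here is a list-based sketch with a hypothetical helper `flatten_batch`:

```python
def flatten_batch(batch: list) -> list:
    """Merge the two leading batch dimensions: (a, b, *state_shape) -> (a * b, *state_shape)."""
    return [state for row in batch for state in row]


batch = [[[0], [0]], [[1], [-1]]]  # batch_shape (2, 2), state_shape (1,)
flat = flatten_batch(batch)
print(flat)  # [[0], [0], [1], [-1]]
```

States come out in row-major order (all states of row 0 first). This matches what reshape or flatten produces on a contiguous torch tensor.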
- classmethod from_batch_shape(batch_shape, random=False)
Create a States object with the given batch shape, with every state initialized to s_0. If random is True, the states are initialized randomly instead. This requires the environment's States subclass to implement the abstract make_random_states_tensor class method.
- Parameters
batch_shape (tuple[int]) –
random (bool) –
- Return type
States
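The sketch below mimics this behaviour on nested lists for a hypothetical 2D grid environment. Every state defaults to `S0 = [0, 0]`, and with random_init=True each state is drawn uniformly from the grid. The uniform sampler and the names `make_states`, `S0` and `GRID_SIZE` are all assumptions. In the documented class, random initialization is delegated to make_random_states_tensor. The flag is called random_init here only to avoid shadowing Python's random module:

```python
import random
from typing import Optional

S0 = [0, 0]    # hypothetical initial state s_0 of a 2D grid
GRID_SIZE = 5  # hypothetical grid side length, used only for random states


def make_states(batch_shape: tuple[int, ...], random_init: bool = False,
                rng: Optional[random.Random] = None) -> list:
    """Build a nested-list batch of shape (*batch_shape, 2), filled with s_0 or random states."""
    rng = rng or random.Random()
    if not batch_shape:
        # Leaf of the recursion: a single state.
        if random_init:
            return [rng.randrange(GRID_SIZE), rng.randrange(GRID_SIZE)]
        return list(S0)
    return [make_states(batch_shape[1:], random_init, rng) for _ in range(batch_shape[0])]


initial = make_states((2, 3))
print(initial)  # [[[0, 0], [0, 0], [0, 0]], [[0, 0], [0, 0], [0, 0]]]

randomized = make_states((2, 3), random_init=True, rng=random.Random(0))
print(len(randomized), len(randomized[0]), len(randomized[0][0]))  # 2 3 2
```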
- classmethod make_initial_states_tensor(batch_shape)
- Parameters
batch_shape (tuple[int]) –
- Return type
StatesTensor
- make_masks()
Create the forward and backward masks for the states. This method is called only if the masks are not provided at initialization.
- Return type
tuple[ForwardMasksTensor, BackwardMasksTensor]
- abstract classmethod make_random_states_tensor(batch_shape)
- Parameters
batch_shape (tuple[int]) –
- Return type
StatesTensor
- update_masks()
Update the masks, if necessary. This method should be called after each action is taken.
- Return type
None
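The interplay between make_masks and update_masks can be sketched with a toy, list-based class for a hypothetical one-dimensional environment. Its states are positions 0 to H - 1, its forward actions are "step right" and "exit", and its backward action is "step left". `LineStates` is not torchgfn code. It computes the masks in its constructor only when they are not supplied, and it recomputes them after the states change:

```python
from typing import Optional

H = 4  # hypothetical environment size: positions 0..H-1

Masks = list[list[bool]]


class LineStates:
    """Toy, list-based stand-in for a States subclass (hypothetical, not torchgfn code)."""

    def __init__(self, positions: list[int],
                 forward_masks: Optional[Masks] = None,
                 backward_masks: Optional[Masks] = None) -> None:
        self.positions = positions
        if forward_masks is None or backward_masks is None:
            # Masks were not provided at initialization, so compute them from the states.
            forward_masks, backward_masks = self.make_masks()
        self.forward_masks = forward_masks
        self.backward_masks = backward_masks

    def make_masks(self) -> tuple[Masks, Masks]:
        # Forward actions: (step_right, exit); backward action: (step_left,).
        forward = [[p < H - 1, True] for p in self.positions]
        backward = [[p > 0] for p in self.positions]
        return forward, backward

    def update_masks(self) -> None:
        # Recompute both masks so they reflect the current states.
        self.forward_masks, self.backward_masks = self.make_masks()


states = LineStates([0, 2])
states.positions = [p + 1 for p in states.positions]  # apply "step_right" to every state
states.update_masks()
print(states.forward_masks)   # [[True, True], [False, True]]
print(states.backward_masks)  # [[True], [True]]
```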
- gfn.containers.states.StatesTensor
- gfn.containers.states.correct_cast(forward_masks, backward_masks)
Casts the given masks to the correct type, if they are not None. This function exists only to help with type checking.
- Parameters
forward_masks (ForwardMasksTensor | None) –
backward_masks (BackwardMasksTensor | None) –
- Return type
tuple[ForwardMasksTensor, BackwardMasksTensor]
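Here is a minimal sketch of this typing-only pattern. It assumes the function relies on typing.cast, which is an assumption: this reference only says the function exists to help type checking. The aliases `ForwardMasks` and `BackwardMasks` are hypothetical stand-ins for the tensor types. Because typing.cast returns its argument unchanged at runtime, None passes through as-is, and only the static type checker sees the narrowed types:

```python
from typing import Optional, cast

# Hypothetical aliases standing in for ForwardMasksTensor / BackwardMasksTensor.
ForwardMasks = list[list[bool]]
BackwardMasks = list[list[bool]]


def correct_cast_sketch(
    forward_masks: Optional[ForwardMasks],
    backward_masks: Optional[BackwardMasks],
) -> tuple[ForwardMasks, BackwardMasks]:
    # typing.cast only informs the type checker; at runtime it returns its argument unchanged.
    return cast(ForwardMasks, forward_masks), cast(BackwardMasks, backward_masks)


fwd, bwd = correct_cast_sketch([[True, False]], [[True]])
print(fwd, bwd)  # [[True, False]] [[True]]
```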