gfn.containers
Submodules
Package Contents
Classes
States | Base class for states, seen as nodes of the DAG.
Trajectories | Base class for states containers (states, transitions, or trajectories)
Transitions | Base class for states containers (states, transitions, or trajectories)
- class gfn.containers.States(states_tensor, forward_masks=None, backward_masks=None)
Bases:
gfn.containers.base.Container, abc.ABC
Base class for states, seen as nodes of the DAG. For each environment, a States subclass is needed. A States object is a collection of multiple states (nodes of the DAG). A tensor representation of the states is required for batching. If a state is represented with a tensor of shape (*state_shape), a batch of states is represented with a States object, with the attribute states_tensor of shape (*batch_shape, *state_shape). Other representations are possible (e.g. state as string, as numpy array, as graph, etc…), but these representations should not be batched.
If the environment’s action space is discrete, each States object also has boolean forward_masks and backward_masks attributes indicating which actions are allowed at each state.
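As a rough illustration of what the masks encode, the sketch below builds them in plain Python for a hypothetical counter environment that is not part of gfn: a state is an integer between 0 and MAX_VALUE, action 0 increments it, and action 1 exits. The names MAX_VALUE, forward_mask, and backward_mask are assumptions made for this example, and the library stores masks as boolean tensors rather than lists.

```python
# Hypothetical toy environment, not part of gfn: a state is an integer counter.
MAX_VALUE = 3  # largest value the counter may reach (assumption for this sketch)


def forward_mask(state: int) -> list:
    """One boolean per action (increment, exit): True if allowed at `state`."""
    can_increment = state < MAX_VALUE
    can_exit = True  # in this toy environment, exiting is always allowed
    return [can_increment, can_exit]


def backward_mask(state: int) -> list:
    """One boolean per non-exit action: True if it can be undone at `state`."""
    return [state > 0]


batch = [0, 2, 3]
forward_masks = [forward_mask(s) for s in batch]
backward_masks = [backward_mask(s) for s in batch]
print(forward_masks)
print(backward_masks)
```

Each row of a mask lines up with one state of the batch, so a policy can rule out disallowed actions before sampling.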
A batch_shape attribute is also required, to keep track of the batch dimension. A trajectory can be represented by a States object with batch_shape = (n_states,). Multiple trajectories can be represented by a States object with batch_shape = (n_states, n_trajectories).
Because trajectories can have different lengths, batching requires padding every trajectory that is shorter than the longest one with a dummy state. The dummy state is the s_f attribute of the environment (e.g. [-1, …, -1], or [-inf, …, -inf], etc…). It is never processed and is used only to pad the batch of states.
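The padding scheme can be sketched in plain Python, with tuples standing in for state tensors. The sink state SF = (-1, -1) and the helper pad_trajectories are assumptions made for this example and are not part of the library.

```python
SF = (-1, -1)  # assumed dummy sink state for this example


def pad_trajectories(trajectories, sf=SF):
    """Pad ragged trajectories with `sf` and return a time-major grid.

    grid[t][j] is the state of trajectory j at step t, so the batch shape
    is (n_states, n_trajectories).
    """
    n_states = max(len(traj) for traj in trajectories)
    return [
        [traj[t] if t < len(traj) else sf for traj in trajectories]
        for t in range(n_states)
    ]


trajectories = [
    [(0, 0), (1, 0), (1, 1)],  # three states
    [(0, 0), (0, 1)],          # two states, needs one padding state
]
grid = pad_trajectories(trajectories)
batch_shape = (len(grid), len(grid[0]))
print(batch_shape)  # (3, 2)
print(grid[2])      # [(1, 1), (-1, -1)]
```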
- Parameters
states_tensor (StatesTensor) –
forward_masks (ForwardMasksTensor | None) –
backward_masks (BackwardMasksTensor | None) –
- property device: torch.device
- Return type
torch.device
- property is_initial_state: DonesTensor
Return a boolean tensor of shape=(*batch_shape,), where True means that the state is \(s_0\) of the DAG.
- Return type
DonesTensor
- property is_sink_state: DonesTensor
Return a boolean tensor of shape=(*batch_shape,), where True means that the state is \(s_f\) of the DAG.
- Return type
DonesTensor
- property log_rewards: RewardsTensor
- Return type
RewardsTensor
- s0: ClassVar[OneStateTensor]
- sf: ClassVar[OneStateTensor]
- state_shape: ClassVar[tuple[int, ...]]
- __getitem__(index)
Access particular states of the batch.
- Parameters
index (int | Sequence[int] | Sequence[bool]) –
- Return type
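The three accepted index kinds (a single integer, a sequence of positions, a boolean mask) can be illustrated on a plain list of states. The helper select is hypothetical, and returning a batch of one for an integer index is an assumption of this sketch rather than documented behavior.

```python
def select(states, index):
    """Return the sub-batch of `states` picked out by `index`."""
    if isinstance(index, int):
        # Assumption: a single index yields a batch containing one state.
        return [states[index]]
    index = list(index)
    if index and all(isinstance(i, bool) for i in index):
        if len(index) != len(states):
            raise IndexError("a boolean mask needs one entry per state")
        return [s for s, keep in zip(states, index) if keep]
    return [states[i] for i in index]


states = [(0, 0), (1, 0), (1, 1), (-1, -1)]
by_int = select(states, 1)
by_positions = select(states, [0, 2])
by_mask = select(states, [True, False, False, True])
print(by_int, by_positions, by_mask)
```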
- __len__()
Returns the number of elements in the container
- __repr__()
Return repr(self).
- compare(other)
Given a tensor of states, returns a tensor of booleans indicating whether the states are equal to the states in self.
- Parameters
other (StatesTensor) – Tensor of states to compare to.
- Returns
Tensor of booleans indicating whether the states are equal to the states in self.
- Return type
DonesTensor
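A minimal plain-Python sketch of this comparison follows, with tuples standing in for state tensors. A state counts as equal only if every component matches, which yields one boolean per state. The standalone function compare and the value of SF are assumptions for illustration.

```python
SF = (-1, -1)  # assumed sink state for this example


def compare(states, other):
    """Return one boolean per state: True where the whole state matches."""
    if len(states) != len(other):
        raise ValueError("both batches must contain the same number of states")
    return [tuple(s) == tuple(o) for s, o in zip(states, other)]


states = [(0, 0), (1, 2), (-1, -1)]
is_sink = compare(states, [SF] * len(states))
print(is_sink)  # [False, False, True]
```

Comparing against a batch filled with s_f, as here, gives the same kind of result that is_sink_state describes. Comparing against s_0 would do the same for is_initial_state.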
- extend(other)
Collates with another States object whose batch shape has the same number of dimensions, which should be 1 or 2.
- Parameters
other (States) – Batch of states to collate to.
- Raises
ValueError – if self.batch_shape != other.batch_shape or if self.batch_shape != (1,) or (2,)
- Return type
None
- extend_with_sf(required_first_dim)
Takes a two-dimensional batch of states (i.e. of batch_shape (a, b)), and extends it to a States object of batch_shape (required_first_dim, b), by adding the required number of s_f tensors. This is useful to extend trajectories of different lengths.
- Parameters
required_first_dim (int) –
- Return type
None
- flatten()
Flatten the batch dimension of the states. This is useful for example when extracting individual states from trajectories.
- Return type
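Flattening can be pictured as merging the batch dimensions (n_states, n_trajectories) into a single one. The sketch below does this on a nested list in row-major order. The helper flatten is a stand-in written for this example, and whether the library keeps or drops padding states when flattening is not addressed here.

```python
def flatten(grid):
    """Merge the two batch dimensions of a time-major grid into one."""
    return [state for row in grid for state in row]


grid = [
    [(0, 0), (0, 0)],  # step 0 of both trajectories
    [(1, 0), (0, 1)],  # step 1 of both trajectories
]
flat = flatten(grid)
print(len(flat))  # 4
print(flat)       # [(0, 0), (0, 0), (1, 0), (0, 1)]
```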
- classmethod from_batch_shape(batch_shape, random=False)
Create a States object with the given batch shape, all initialized to s_0. If random is True, the states are initialized randomly. This requires that the environment implements the make_random_states_tensor class method.
- Parameters
batch_shape (tuple[int]) –
random (bool) –
- Return type
States
- classmethod make_initial_states_tensor(batch_shape)
- Parameters
batch_shape (tuple[int]) –
- Return type
StatesTensor
- make_masks()
Create the forward and backward masks for the states. This method is called only if the masks are not provided at initialization.
- Return type
tuple[ForwardMasksTensor, BackwardMasksTensor]
- abstract classmethod make_random_states_tensor(batch_shape)
- Parameters
batch_shape (tuple[int]) –
- Return type
StatesTensor
- update_masks()
Update the masks, if necessary. This method should be called after each action is taken.
- Return type
None
- class gfn.containers.Trajectories(env, states=None, actions=None, when_is_done=None, is_backward=False, log_rewards=None, log_probs=None)
Bases:
gfn.containers.base.Container
Base class for states containers (states, transitions, or trajectories)
- Parameters
env (gfn.envs.Env) –
states (States | None) –
actions (Tensor2D | None) –
when_is_done (Tensor1D | None) –
is_backward (bool) –
log_rewards (FloatTensor1D | None) –
log_probs (FloatTensor2D | None) –
- property last_states: gfn.containers.states.States
- Return type
gfn.containers.states.States
- property log_rewards: FloatTensor1D | None
- Return type
FloatTensor1D | None
- property max_length: int
- Return type
int
- property n_trajectories: int
- Return type
int
- __getitem__(index)
Returns a subset of the n_trajectories trajectories.
- Parameters
index (int | Sequence[int]) –
- Return type
- __len__()
Returns the number of elements in the container
- Return type
int
- __repr__()
Return repr(self).
- Return type
str
- extend(other)
Extend the trajectories with another set of trajectories.
- Parameters
other (Trajectories) –
- Return type
None
- extend_actions(required_first_dim)
Extends the actions and log_probs along the first dimension by adding -1s as necessary. This is useful for extending trajectories of different lengths.
- Parameters
required_first_dim (int) –
- Return type
None
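As a hedged sketch of the padding idea, the plain-Python function below appends rows of -1 to a time-major list of actions until it reaches the required first dimension. Only actions are shown, and the same idea would apply to log_probs. The function is a stand-in for illustration, and leaving the input unchanged when it is already long enough is an assumption.

```python
PAD_ACTION = -1  # padding value named in the description above


def extend_actions(actions, required_first_dim, pad=PAD_ACTION):
    """Append rows of `pad` in place until there are `required_first_dim` rows."""
    n_trajectories = len(actions[0])
    while len(actions) < required_first_dim:
        actions.append([pad] * n_trajectories)


actions = [[0, 1], [1, 1]]  # shape (2, 2): two steps, two trajectories
extend_actions(actions, 4)
print(actions)  # [[0, 1], [1, 1], [-1, -1], [-1, -1]]
```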
- static revert_backward_trajectories(trajectories)
- Parameters
trajectories (Trajectories) –
- Return type
- to_non_initial_intermediary_and_terminating_states()
Returns a tuple of States objects from the trajectories, containing all non-initial intermediary and all terminating states in the trajectories
- to_states()
Returns a States object from the trajectories, containing all states in the trajectories
- Return type
States
- to_transitions()
Returns a Transitions object from the trajectories
- Return type
Transitions
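Conceptually, converting a trajectory into transitions pairs each state with the action taken and the state that follows. The sketch below does this for a single trajectory in plain Python. The Transition tuple, the sink state SF, and the convention that the state list ends with s_f are assumptions made for this example, not the library's internal layout.

```python
from typing import NamedTuple

SF = (-1, -1)  # assumed sink state for this example


class Transition(NamedTuple):
    state: tuple
    action: int
    next_state: tuple
    is_done: bool


def to_transitions(states, actions, sf=SF):
    """Pair each state with its action and successor; flag steps reaching `sf`."""
    if len(states) != len(actions) + 1:
        raise ValueError("expected one more state than actions")
    return [
        Transition(s, a, s_next, s_next == sf)
        for s, a, s_next in zip(states, actions, states[1:])
    ]


states = [(0, 0), (1, 0), (1, 1), SF]
actions = [0, 1, 2]  # action 2 plays the role of the exit action here
transitions = to_transitions(states, actions)
print(len(transitions))         # 3
print(transitions[-1].is_done)  # True
```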
- class gfn.containers.Transitions(env, states=None, actions=None, is_done=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)
Bases:
gfn.containers.base.Container
Base class for states containers (states, transitions, or trajectories)
- Parameters
env (gfn.envs.Env) –
states (States | None) –
actions (LongTensor | None) –
is_done (BoolTensor | None) –
next_states (States | None) –
is_backward (bool) –
log_rewards (FloatTensor | None) –
log_probs (FloatTensor | None) –
- property all_log_rewards: PairFloatTensor
This is applicable to environments where all states are terminating. This function evaluates the rewards for all transitions that do not end in the sink state. This is useful for the Modified Detailed Balance loss.
- Return type
PairFloatTensor
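One way to read this description is that every transition whose next state is not the sink state gets a pair of log rewards, one for the state and one for the next state. The sketch below follows that reading in plain Python. The reward function log_reward, the sink state SF, and the choice to skip sink transitions rather than fill them with a placeholder are all assumptions for illustration.

```python
import math

SF = (-1, -1)  # assumed sink state for this example


def log_reward(state):
    """Hypothetical reward R(s) = 1 + sum of coordinates, returned as log R(s)."""
    return math.log(1 + sum(state))


def all_log_rewards(transitions, sf=SF):
    """Return (log R(s), log R(s')) for each transition not ending in `sf`."""
    return [
        (log_reward(s), log_reward(s_next))
        for s, s_next in transitions
        if s_next != sf
    ]


transitions = [((0, 0), (1, 0)), ((1, 0), (1, 1)), ((1, 1), SF)]
pairs = all_log_rewards(transitions)
print(len(pairs))  # 2
```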
- property last_states: gfn.containers.states.States
Get the last states, i.e. terminating states
- Return type
gfn.containers.states.States
- property log_rewards: FloatTensor | None
- Return type
FloatTensor | None
- property n_transitions: int
- Return type
int
- __getitem__(index)
Access particular transitions of the batch.
- Parameters
index (int | Sequence[int]) –
- Return type
- __len__()
Returns the number of elements in the container
- Return type
int
- __repr__()
Return repr(self).
- extend(other)
Extend the Transitions object with another Transitions object.
- Parameters
other (Transitions) –
- Return type
None