gfn.containers.trajectories
Module Contents
Classes
- gfn.containers.trajectories.Trajectories: Base class for states containers (states, transitions, or trajectories)
Attributes
- gfn.containers.trajectories.FloatTensor1D
- gfn.containers.trajectories.FloatTensor2D
- gfn.containers.trajectories.Tensor1D
- gfn.containers.trajectories.Tensor2D
- gfn.containers.trajectories.Tensor2D2
- class gfn.containers.trajectories.Trajectories(env, states=None, actions=None, when_is_done=None, is_backward=False, log_rewards=None, log_probs=None)
Bases: gfn.containers.base.Container
Base class for states containers (states, transitions, or trajectories)
- Parameters
env (gfn.envs.Env)
states (States | None)
actions (Tensor2D | None)
when_is_done (Tensor1D | None)
is_backward (bool)
log_rewards (FloatTensor1D | None)
log_probs (FloatTensor2D | None)
- property last_states: gfn.containers.states.States
- Return type
gfn.containers.states.States
- property log_rewards: FloatTensor1D | None
- Return type
FloatTensor1D | None
- property max_length: int
- Return type
int
- property n_trajectories: int
- Return type
int
- __getitem__(index)
Returns a subset of the n_trajectories trajectories.
- Parameters
index (int | Sequence[int])
- Return type
gfn.containers.trajectories.Trajectories
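To make the subset semantics concrete, here is a minimal plain-Python sketch. It assumes a padded, time-major layout that the reference above does not spell out: `actions[t][i]` is the action taken at step `t` of trajectory `i`, finished trajectories are padded with `-1`, and `when_is_done[i]` is the number of steps in trajectory `i`. The function `select_trajectories` is hypothetical and not part of gfn.

```python
from typing import List, Sequence, Tuple, Union

PAD = -1  # assumed padding value for steps after a trajectory has ended


def select_trajectories(
    actions: List[List[int]],
    when_is_done: List[int],
    index: Union[int, Sequence[int]],
) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical sketch: keep only the trajectories listed in `index`."""
    if isinstance(index, int):
        index = [index]
    done = [when_is_done[i] for i in index]
    # The subset only needs as many time steps as its longest trajectory.
    new_length = max(done, default=0)
    # Keep the selected columns of the first `new_length` rows.
    sub_actions = [[row[i] for i in index] for row in actions[:new_length]]
    return sub_actions, done


actions = [
    [0, 2, 5],
    [1, PAD, 6],
    [3, PAD, PAD],
]
when_is_done = [3, 1, 2]
print(select_trajectories(actions, when_is_done, [1, 2]))  # ([[2, 5], [-1, 6]], [1, 2])
```

Trimming to the longest selected trajectory is a design choice of this sketch that keeps padding minimal.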
- __len__()
Returns the number of elements in the container (here, the number of trajectories).
- Return type
int
- __repr__()
Return repr(self).
- Return type
str
- extend(other)
Extend the trajectories with another set of trajectories.
- Parameters
other (Trajectories)
- Return type
None
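A rough sketch of what extending one batch with another involves, under an assumed time-major layout (rows are time steps, columns are trajectories, `-1` marks padding): pad both batches to the longer length, then concatenate along the trajectory dimension. `pad_to_length` and `extend_batch` are hypothetical helpers. Unlike the documented method, which returns None and updates the container in place, this sketch returns new lists so the result is easy to inspect.

```python
from typing import List, Tuple

PAD = -1  # assumed padding value for "no action"


def pad_to_length(actions: List[List[int]], n_trajectories: int, length: int) -> List[List[int]]:
    """Copy `actions` and append PAD rows until it has `length` time steps (never truncates)."""
    missing = max(length - len(actions), 0)
    return [row[:] for row in actions] + [[PAD] * n_trajectories for _ in range(missing)]


def extend_batch(
    actions_a: List[List[int]], done_a: List[int],
    actions_b: List[List[int]], done_b: List[int],
) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical sketch of concatenating two time-major batches of trajectories."""
    length = max(len(actions_a), len(actions_b))
    # Bring both batches to the same number of time steps first ...
    a = pad_to_length(actions_a, len(done_a), length)
    b = pad_to_length(actions_b, len(done_b), length)
    # ... then place the trajectories of b to the right of those of a.
    merged = [row_a + row_b for row_a, row_b in zip(a, b)]
    return merged, done_a + done_b


merged, done = extend_batch([[0], [1]], [2], [[3, 4]], [1, 1])
print(merged)  # [[0, 3, 4], [1, -1, -1]]
print(done)    # [2, 1, 1]
```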
- extend_actions(required_first_dim)
Extends the actions and log_probs along the first dimension to required_first_dim by appending -1s as necessary. This is useful when combining trajectories of different lengths.
- Parameters
required_first_dim (int)
- Return type
None
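The padding step can be sketched in plain Python as follows, assuming `actions` is a list of rows with one row per time step and one column per trajectory. The function below is a hypothetical stand-in, not the gfn implementation: like the documented method it mutates its argument and returns None, but it only pads actions, not log_probs.

```python
from typing import List

PAD = -1  # value appended for time steps that do not exist yet


def extend_actions(actions: List[List[int]], n_trajectories: int, required_first_dim: int) -> None:
    """Hypothetical in-place sketch: grow the first (time) dimension with rows of -1."""
    # Nothing happens if the first dimension is already long enough.
    while len(actions) < required_first_dim:
        actions.append([PAD] * n_trajectories)


actions = [[0, 2], [1, PAD]]
extend_actions(actions, n_trajectories=2, required_first_dim=4)
print(actions)  # [[0, 2], [1, -1], [-1, -1], [-1, -1]]
```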
- static revert_backward_trajectories(trajectories)
- Parameters
trajectories (Trajectories)
- Return type
gfn.containers.trajectories.Trajectories
- to_non_initial_intermediary_and_terminating_states()
Returns a tuple of States objects from the trajectories: one with all non-initial intermediary states and one with all terminating states.
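As a toy illustration of this split, the sketch below assumes each trajectory is given as an unpadded list `[s_0, s_1, ..., s_n]` where `s_0` is the initial state and `s_n` the terminating state, and that every trajectory has at least one state. `split_states` is hypothetical; gfn works on batched States objects rather than Python lists, and its handling of sink or padding states may differ.

```python
from typing import List, Sequence, Tuple, TypeVar

S = TypeVar("S")


def split_states(trajectories: Sequence[Sequence[S]]) -> Tuple[List[S], List[S]]:
    """Hypothetical sketch: (non-initial intermediary states, terminating states)."""
    intermediary: List[S] = []
    terminating: List[S] = []
    for traj in trajectories:
        # Drop the initial state (index 0) and the terminating state (index -1).
        intermediary.extend(traj[1:-1])
        # The last state of each trajectory is treated as terminating.
        terminating.append(traj[-1])
    return intermediary, terminating


trajs = [["s0", "a", "b"], ["s0", "c"]]
print(split_states(trajs))  # (['a'], ['b', 'c'])
```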
- to_states()
Returns a States object containing all states in the trajectories.
- Return type
gfn.containers.states.States
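Flattening padded trajectories into a single collection of states means skipping the padding. This hypothetical sketch assumes `states[t][i]` is the state of trajectory `i` after `t` steps, that trajectory `i` occupies rows `0` through `when_is_done[i]`, and that the string `"<pad>"` marks padding; none of these names come from gfn.

```python
from typing import List

PAD = "<pad>"  # assumed marker for time steps after a trajectory has ended


def to_states(states: List[List[str]], when_is_done: List[int]) -> List[str]:
    """Hypothetical sketch: collect every non-padding state, trajectory by trajectory."""
    flat: List[str] = []
    for i, n_steps in enumerate(when_is_done):
        # Trajectory i visits rows 0..n_steps (inclusive) of column i.
        flat.extend(states[t][i] for t in range(n_steps + 1))
    return flat


states = [
    ["s0", "s0"],
    ["a", "c"],
    ["b", PAD],
]
print(to_states(states, [2, 1]))  # ['s0', 'a', 'b', 's0', 'c']
```

States are emitted trajectory by trajectory here; that ordering is an arbitrary choice of the sketch.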
- to_transitions()
Returns a Transitions object built from the trajectories.
- Return type
gfn.containers.transitions.Transitions
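A transition is one step `(state, action, next_state)` of a trajectory, so converting trajectories into transitions amounts to walking each trajectory and emitting consecutive pairs of states with the action between them. The sketch below assumes a hypothetical padded, time-major layout stored as plain Python lists; the real Transitions container holds batched tensors and may carry fields not modeled here.

```python
from typing import List, Tuple

PAD_STATE = "<pad>"  # assumed marker for missing states
PAD_ACTION = -1      # assumed marker for missing actions


def to_transitions(
    states: List[List[str]], actions: List[List[int]], when_is_done: List[int]
) -> List[Tuple[str, int, str]]:
    """Hypothetical sketch: one (state, action, next_state) triple per step."""
    transitions: List[Tuple[str, int, str]] = []
    for i, n_steps in enumerate(when_is_done):
        # Step t of trajectory i goes from states[t][i] to states[t + 1][i].
        for t in range(n_steps):
            transitions.append((states[t][i], actions[t][i], states[t + 1][i]))
    return transitions


states = [["s0", "s0"], ["a", "c"], ["b", PAD_STATE]]
actions = [[0, 2], [1, PAD_ACTION]]
print(to_transitions(states, actions, [2, 1]))
# [('s0', 0, 'a'), ('a', 1, 'b'), ('s0', 2, 'c')]
```

A trajectory with `n` steps contributes `n` transitions, so under this layout the total is `sum(when_is_done)`.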