gfn.containers.transitions
Module Contents
Classes
Transitions: a container for a batch of transitions, built on the base class for states containers (states, transitions, or trajectories).
Attributes
- gfn.containers.transitions.BoolTensor
- gfn.containers.transitions.FloatTensor
- gfn.containers.transitions.LongTensor
- gfn.containers.transitions.PairFloatTensor
- class gfn.containers.transitions.Transitions(env, states=None, actions=None, is_done=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)
Bases: gfn.containers.base.Container
Base class for states containers (states, transitions, or trajectories).
- Parameters
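env (gfn.envs.Env) – the environment the transitions belong to.
states (States | None) – the states each transition starts from.
actions (LongTensor | None) – the actions taken in states.
is_done (BoolTensor | None) – whether each transition is terminating, i.e. ends in the sink state.
next_states (States | None) – the states each transition leads to.
is_backward (bool) – whether these are backward transitions.
log_rewards (FloatTensor | None) – log-rewards of the transitions.
log_probs (FloatTensor | None) – log-probabilities of the actions.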
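To make the per-transition fields concrete, here is a minimal sketch in plain Python. It does not use torchgfn: `ToyTransitions`, `from_trajectory`, `SINK`, `STEP` and `EXIT` are hypothetical names for a toy 1-D chain, and lists stand in for the `States` and tensor fields. It assumes that each field holds one entry per transition and that the exit step of a trajectory is the only transition with `is_done` set, leading to the sink state.

```python
from dataclasses import dataclass
from typing import List

SINK = -1          # hypothetical encoding of the sink state s_f
STEP, EXIT = 0, 1  # hypothetical actions of a toy 1-D chain


@dataclass
class ToyTransitions:
    """Hypothetical stand-in for Transitions: one entry per transition in each field."""

    states: List[int]
    actions: List[int]
    is_done: List[bool]
    next_states: List[int]
    is_backward: bool = False

    def __post_init__(self) -> None:
        # All per-transition fields must line up index by index.
        n = len(self.states)
        if not (len(self.actions) == len(self.is_done) == len(self.next_states) == n):
            raise ValueError("all per-transition fields must have the same length")


def from_trajectory(visited: List[int], actions: List[int]) -> ToyTransitions:
    """Split one forward trajectory s_0 -> ... -> s_n -> s_f into single-step transitions.

    actions[i] is the action taken in visited[i]; the last one must be EXIT.
    """
    if not actions or len(visited) != len(actions) or actions[-1] != EXIT:
        raise ValueError("expected one action per visited state, ending with EXIT")
    return ToyTransitions(
        states=list(visited),
        actions=list(actions),
        is_done=[a == EXIT for a in actions],  # only the exit step terminates
        next_states=visited[1:] + [SINK],      # the exit step leads to the sink
    )


t = from_trajectory([0, 1, 2], [STEP, STEP, EXIT])
```

Splitting the trajectory `0 -> 1 -> 2 -> s_f` this way yields three transitions, whose `next_states` are `[1, 2, SINK]` and whose `is_done` flags are `[False, False, True]`.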
- property all_log_rewards: PairFloatTensor
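Log-rewards of both endpoints of each transition: for every transition that does not end in the sink state, the log-reward of its starting state and of its next state are evaluated, giving a pair of values per transition. Applicable only to environments in which every state is terminating, so that every state has a reward. Used by the Modified Detailed Balance loss.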
- Return type
PairFloatTensor
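The sketch below illustrates, in plain Python rather than torchgfn, what a pair of log-rewards per transition looks like and why the Modified Detailed Balance loss needs both entries. The names `all_log_rewards`, `modified_db_residual` and `SINK`, and the `-inf` filler for sink-bound transitions, are assumptions of this sketch. The residual follows one common statement of the modified detailed balance condition, log R(s') + log P_B(s | s') + log P_F(s_f | s) = log R(s) + log P_F(s' | s) + log P_F(s_f | s'), which only makes sense when every state s has a reward R(s).

```python
import math
from typing import Callable, List, Sequence, Tuple

SINK = -1  # hypothetical encoding of the sink state s_f


def all_log_rewards(
    states: Sequence[int],
    next_states: Sequence[int],
    log_reward: Callable[[int], float],
) -> List[Tuple[float, float]]:
    """One (log R(s), log R(s')) pair per transition.

    Transitions that end in the sink are not evaluated; filling them with -inf
    is an assumption of this sketch.
    """
    pairs = []
    for s, s_next in zip(states, next_states):
        if s_next == SINK:
            pairs.append((-math.inf, -math.inf))
        else:
            pairs.append((log_reward(s), log_reward(s_next)))
    return pairs


def modified_db_residual(
    log_r_s: float,           # log R(s)
    log_r_next: float,        # log R(s')
    log_pf_next: float,       # log P_F(s' | s)
    log_pb_s: float,          # log P_B(s | s')
    log_pf_exit_s: float,     # log P_F(s_f | s)
    log_pf_exit_next: float,  # log P_F(s_f | s')
) -> float:
    """Zero when the modified detailed balance condition holds for s -> s'."""
    lhs = log_r_next + log_pb_s + log_pf_exit_s
    rhs = log_r_s + log_pf_next + log_pf_exit_next
    return lhs - rhs


# Toy chain 0 -> 1 -> 2 -> s_f with reward R(s) = s + 1.
pairs = all_log_rewards([0, 1, 2], [1, 2, SINK], lambda s: math.log(s + 1))
```

A loss built on this would typically square `modified_db_residual` for each non-sink transition, reading log R(s) and log R(s') from the corresponding pair.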
- property last_states: gfn.containers.states.States
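Get the last states, i.e. the terminating states: the starting states of the transitions whose is_done flag is set.
- Return type
gfn.containers.states.States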
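A minimal sketch of selecting terminating states, assuming `last_states` behaves like a boolean mask over the starting states using `is_done`. Plain Python lists stand in for `States`, and `last_states` here is a hypothetical free function, not the torchgfn property.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def last_states(states: Sequence[T], is_done: Sequence[bool]) -> List[T]:
    """Boolean-mask selection: keep states[i] wherever is_done[i] is True."""
    if len(states) != len(is_done):
        raise ValueError("states and is_done must have the same length")
    return [s for s, done in zip(states, is_done) if done]


# Two toy trajectories, 0 -> 1 -> 2 -> s_f and 0 -> 1 -> s_f, stored back to back.
states = [0, 1, 2, 0, 1]
is_done = [False, False, True, False, True]
terminating = last_states(states, is_done)  # [2, 1]
```

Because the exit action is taken from the terminating state, masking the starting states (rather than `next_states`, which would all be the sink) recovers one terminating state per finished trajectory.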
- property log_rewards: FloatTensor | None
- Return type
FloatTensor | None
- property n_transitions: int
- Return type
int
- __getitem__(index)
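Access particular transitions of the batch. Returns a new Transitions object holding only the selected transitions.
- Parameters
index (int | Sequence[int]) – position(s) of the transitions to select.
- Return type
Transitions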
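A minimal sketch of the indexing contract, assuming that an `int` index is treated like a one-element sequence so that the result is always a smaller container rather than a single transition. `ToyTransitions` is a hypothetical list-backed stand-in, not the torchgfn class.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class ToyTransitions:
    states: List[int]
    actions: List[int]
    is_done: List[bool]
    next_states: List[int]

    def __len__(self) -> int:
        return len(self.states)

    def __getitem__(self, index: Union[int, Sequence[int]]) -> "ToyTransitions":
        # Normalise a single int to a one-element index list so the result
        # is always a container, never a bare transition.
        idx = [index] if isinstance(index, int) else list(index)
        return ToyTransitions(
            states=[self.states[i] for i in idx],
            actions=[self.actions[i] for i in idx],
            is_done=[self.is_done[i] for i in idx],
            next_states=[self.next_states[i] for i in idx],
        )


t = ToyTransitions(
    states=[0, 1, 2], actions=[0, 0, 1],
    is_done=[False, False, True], next_states=[1, 2, -1],
)
sub = t[[0, 2]]  # first and last transitions
one = t[1]       # a batch containing one transition
```

Normalising the index first keeps the return type uniform, so callers can always treat the result as a batch.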
- __len__()
Return the number of transitions in the container.
- Return type
int
- __repr__()
Return repr(self).
- extend(other)
Extend the Transitions object with another Transitions object.
- Parameters
other (Transitions) –
- Return type
None
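A minimal sketch of in-place extension, assuming `extend` concatenates every per-transition field and returns `None`, as the documented return type indicates. `ToyTransitions` is a hypothetical list-backed stand-in, and rejecting a mix of forward and backward transitions is an assumption of the sketch rather than documented behavior.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyTransitions:
    states: List[int]
    actions: List[int]
    is_done: List[bool]
    next_states: List[int]
    is_backward: bool = False

    def __len__(self) -> int:
        return len(self.states)

    def extend(self, other: "ToyTransitions") -> None:
        """Append other's transitions to self in place; nothing is returned."""
        # Refusing to mix directions is an assumption of this sketch.
        if self.is_backward != other.is_backward:
            raise ValueError("cannot mix forward and backward transitions")
        self.states.extend(other.states)
        self.actions.extend(other.actions)
        self.is_done.extend(other.is_done)
        self.next_states.extend(other.next_states)


buffer = ToyTransitions([0, 1], [0, 1], [False, True], [1, -1])
batch = ToyTransitions([0], [1], [True], [-1])
result = buffer.extend(batch)
```

This pattern fits a replay buffer that accumulates transitions across training iterations: `batch` is left unchanged, while `buffer` grows to three transitions.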