gfn.containers.transitions

Module Contents

Classes

Transitions

Container for the transitions.

class gfn.containers.transitions.Transitions(env, states=None, actions=None, is_done=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)

Bases: gfn.containers.base.Container

Container for the transitions.

Parameters
  • env (gfn.env.Env) –

  • states (States | None) –

  • actions (Actions | None) –

  • is_done (TT['n_transitions', torch.bool] | None) –

  • next_states (States | None) –

  • is_backward (bool) –

  • log_rewards (TT['n_transitions', torch.float] | None) –

  • log_probs (TT['n_transitions', torch.float] | None) –

env

The environment.

is_backward

Whether the transitions are backward transitions (i.e. next_states is the parent of states).

states

States object with uni-dimensional batch_shape, representing the parents of the transitions.

actions

Actions chosen at the parent of each transition.

is_done

Whether the action is the exit action.

next_states

States object with uni-dimensional batch_shape, representing the children of the transitions.

log_probs

The log-probabilities of the actions.

property all_log_rewards: torchtyping.TensorType[n_transitions, 2, torch.float]

Calculate all log rewards for the transitions.

This applies to environments where every state is terminating. The property evaluates the log rewards for all transitions that do not end in the sink state, which is useful for the Modified Detailed Balance loss.

Raises

NotImplementedError – when used for backward transitions.

Return type

torchtyping.TensorType[n_transitions, 2, torch.float]
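
The [n_transitions, 2] shape can be illustrated with a minimal sketch in plain Python, where lists stand in for tensors. This is not the library's implementation. The helper name all_log_rewards_sketch, the toy reward, the column order (parent first, child second) and the -inf fill value for transitions that end in the sink state are all assumptions made for illustration.

```python
import math


def all_log_rewards_sketch(states, next_states, log_reward, sink_state):
    """Return one [log R(state), log R(next_state)] row per transition.

    Rows whose next state is the sink keep the fill value (-inf here,
    which is an assumption, not documented behavior).
    """
    rows = []
    for state, next_state in zip(states, next_states):
        if next_state == sink_state:
            # The transition ends in the sink state: nothing is evaluated.
            rows.append([-math.inf, -math.inf])
        else:
            rows.append([log_reward(state), log_reward(next_state)])
    return rows


# Hypothetical toy environment: states are integers, -1 is the sink state,
# and the reward of state s is s + 1, so every state has a positive reward.
SINK = -1


def toy_log_reward(state):
    return math.log(state + 1)


rows = all_log_rewards_sketch(
    states=[0, 1, 2],
    next_states=[1, 2, SINK],
    log_reward=toy_log_reward,
    sink_state=SINK,
)
```

Each row pairs the log reward of a state with that of its successor, which is why the property only makes sense when every state is terminating and therefore has a reward.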

property last_states: gfn.states.States

Get the last states, i.e. the terminating states.

Return type

gfn.states.States

property log_rewards: TT['n_transitions', torch.float] | None

Return type

TT['n_transitions', torch.float] | None

property n_transitions: int

Return type

int

__getitem__(index)

Access particular transitions of the batch.

Parameters

index (int | Sequence[int]) –

Return type

Transitions
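
As a rough sketch of the indexing contract above (an int or a sequence of ints goes in, a container comes out), the following uses a hypothetical ToyTransitions class with plain lists in place of States, Actions and tensors. The class, its fields' integer encoding, and the choice of 9 and -1 as exit action and sink state are assumptions for illustration, not the library's code.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class ToyTransitions:
    """Hypothetical stand-in for Transitions; lists replace tensors."""

    states: List[int]
    actions: List[int]
    is_done: List[bool]
    next_states: List[int]

    def __len__(self) -> int:
        return len(self.states)

    def __getitem__(self, index: Union[int, Sequence[int]]) -> "ToyTransitions":
        # Normalize a single int to a one-element list so the result is
        # always a container, never a bare element.
        indices = [index] if isinstance(index, int) else list(index)
        return ToyTransitions(
            states=[self.states[i] for i in indices],
            actions=[self.actions[i] for i in indices],
            is_done=[self.is_done[i] for i in indices],
            next_states=[self.next_states[i] for i in indices],
        )


batch = ToyTransitions(
    states=[0, 1, 2],
    actions=[1, 1, 9],             # 9 plays the role of the exit action
    is_done=[False, False, True],
    next_states=[1, 2, -1],        # -1 plays the role of the sink state
)

single = batch[2]        # int index
subset = batch[[0, 2]]   # sequence of ints
```

Every per-transition field is sliced with the same indices, so the selected states, actions, is_done flags and next states stay aligned.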

__len__()

Return type

int

__repr__()

extend(other)

Extend the Transitions object with another Transitions object.

Parameters

other (Transitions) –

Return type

None
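
The None return type suggests that extend modifies the object in place rather than returning a new container. A minimal sketch of that pattern, assuming a hypothetical ToyTransitions class that stores plain lists instead of States objects and tensors:

```python
from typing import List


class ToyTransitions:
    """Hypothetical stand-in for Transitions; lists replace tensors."""

    def __init__(self, states: List[int], next_states: List[int], is_done: List[bool]):
        self.states = list(states)
        self.next_states = list(next_states)
        self.is_done = list(is_done)

    def __len__(self) -> int:
        return len(self.states)

    def extend(self, other: "ToyTransitions") -> None:
        # Append every per-transition field of `other` to the matching
        # field of `self`, keeping the fields aligned index by index.
        self.states.extend(other.states)
        self.next_states.extend(other.next_states)
        self.is_done.extend(other.is_done)


first = ToyTransitions(states=[0, 1], next_states=[1, 2], is_done=[False, False])
second = ToyTransitions(states=[2], next_states=[-1], is_done=[True])

result = first.extend(second)  # mutates `first`; `second` is left unchanged
```

Because the call returns None, the extended batch is read from the object the method was called on, here first.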