gfn.containers.transitions

Module Contents

Classes

Transitions

Container for a batch of transitions (state, action, next state) sampled in an environment.

Attributes

BoolTensor

FloatTensor

LongTensor

PairFloatTensor

gfn.containers.transitions.BoolTensor
gfn.containers.transitions.FloatTensor
gfn.containers.transitions.LongTensor
gfn.containers.transitions.PairFloatTensor
class gfn.containers.transitions.Transitions(env, states=None, actions=None, is_done=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)

Bases: gfn.containers.base.Container

Container for a batch of transitions (state, action, next state) sampled in an environment.

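Parameters
  • env (gfn.envs.Env) – environment in which the transitions were sampled.

  • states (States | None) – states from which the transitions start.

  • actions (LongTensor | None) – actions taken in states.

  • is_done (BoolTensor | None) – whether the action of each transition is the exit action, i.e. whether the transition terminates its trajectory.

  • next_states (States | None) – states reached after taking actions in states.

  • is_backward (bool) – whether the transitions are backward transitions. Defaults to False.

  • log_rewards (FloatTensor | None) – log rewards associated with the transitions.

  • log_probs (FloatTensor | None) – log probabilities of the actions.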
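The constructor takes one entry per transition in each field, so all fields must describe the same batch. The following minimal sketch illustrates that invariant, assuming plain Python lists stand in for the library's States objects and tensors and omitting env; the class name TransitionBatchSketch is hypothetical and not part of gfn.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TransitionBatchSketch:
    """Hypothetical stand-in for Transitions: each list holds one entry per transition."""

    states: List[int] = field(default_factory=list)       # start state of each transition
    actions: List[int] = field(default_factory=list)      # action taken in that state
    is_done: List[bool] = field(default_factory=list)     # True if the action ends the trajectory
    next_states: List[int] = field(default_factory=list)  # state reached after the action
    is_backward: bool = False
    log_rewards: Optional[List[float]] = None
    log_probs: Optional[List[float]] = None

    def __post_init__(self) -> None:
        # Every per-transition field must describe the same number of transitions.
        n = len(self.states)
        for name in ("actions", "is_done", "next_states", "log_rewards", "log_probs"):
            value = getattr(self, name)
            if value is not None and len(value) != n:
                raise ValueError(f"{name} has length {len(value)}, expected {n}")

    @property
    def n_transitions(self) -> int:
        return len(self.states)


batch = TransitionBatchSketch(
    states=[0, 1, 2],
    actions=[1, 1, 0],
    is_done=[False, False, True],
    next_states=[1, 2, -1],  # -1 is used here as a stand-in for the sink state
)
print(batch.n_transitions)  # 3
```

Validating lengths at construction time catches mismatched inputs early, before any loss is computed on misaligned data.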

property all_log_rewards: PairFloatTensor

This property applies only to environments in which every state is terminating. For each transition that does not end in the sink state, it evaluates the log rewards of the state and of the next state and returns them as a pair. These pairs are what the Modified Detailed Balance loss requires.

Return type

PairFloatTensor
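A pair is returned presumably because the Modified Detailed Balance loss needs a reward at both endpoints of a transition. The sketch below assumes the pair is ordered as (log reward of the state, log reward of the next state) and that transitions ending in the sink state are filled with -inf; it uses plain Python in place of tensors, and SINK and all_log_rewards_sketch are hypothetical names.

```python
import math
from typing import Callable, List, Sequence, Tuple

SINK = -1  # hypothetical sentinel marking the sink state


def all_log_rewards_sketch(
    states: Sequence[int],
    next_states: Sequence[int],
    reward: Callable[[int], float],
) -> List[Tuple[float, float]]:
    """Return (log R(s), log R(s')) per transition; (-inf, -inf) if s' is the sink."""
    pairs = []
    for s, s_next in zip(states, next_states):
        if s_next == SINK:
            # Transitions into the sink state get no meaningful reward pair.
            pairs.append((-math.inf, -math.inf))
        else:
            # Every state is terminating, so both endpoints have a reward.
            pairs.append((math.log(reward(s)), math.log(reward(s_next))))
    return pairs


pairs = all_log_rewards_sketch([0, 1, 2], [1, 2, SINK], reward=lambda s: s + 1.0)
print(pairs)
```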

property last_states: gfn.containers.states.States

Get the last states, i.e. the terminating states: the states of the transitions whose is_done flag is True.

Return type

gfn.containers.states.States
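Selecting the terminating states amounts to boolean masking with is_done, analogous to indexing a tensor with a boolean mask. A plain-Python sketch, assuming states and flags are given as parallel sequences; last_states_sketch is a hypothetical helper, not a gfn function.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def last_states_sketch(states: Sequence[T], is_done: Sequence[bool]) -> List[T]:
    """Keep the states whose transition is flagged as done (boolean-mask selection)."""
    if len(states) != len(is_done):
        raise ValueError("states and is_done must have the same length")
    return [s for s, done in zip(states, is_done) if done]


result = last_states_sketch(["s0", "s1", "s2", "s3"], [False, True, False, True])
print(result)  # ['s1', 's3']
```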

property log_rewards: FloatTensor | None

Log rewards of the transitions, or None if they are not available.

Return type

FloatTensor | None
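The FloatTensor | None return type suggests the log rewards may be stored or derived. One plausible reading, stated purely as an assumption rather than gfn's actual logic: stored log rewards are returned if present, backward transitions have none, and otherwise each transition gets log R(s) if it is done and -inf (reward 0) if not. The helper log_rewards_sketch is hypothetical.

```python
import math
from typing import Callable, List, Optional, Sequence


def log_rewards_sketch(
    states: Sequence[int],
    is_done: Sequence[bool],
    reward: Callable[[int], float],
    stored: Optional[List[float]] = None,
    is_backward: bool = False,
) -> Optional[List[float]]:
    """Hypothetical per-transition log-reward vector (assumption, not gfn's code)."""
    if stored is not None:
        return stored  # precomputed log rewards take precedence
    if is_backward:
        return None  # assumed: no rewards defined for backward transitions
    # Done transitions get log R(s); all others get -inf, i.e. a reward of 0.
    return [math.log(reward(s)) if done else -math.inf for s, done in zip(states, is_done)]


values = log_rewards_sketch([0, 1, 2], [False, False, True], reward=lambda s: s + 1.0)
print(values)
```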

property n_transitions: int

Number of transitions in the batch.

Return type

int

__getitem__(index)

Access particular transitions of the batch.

Parameters

index (int | Sequence[int]) – index or indices of the transitions to select.

Return type

Transitions
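Because the return type is Transitions for both an int and a sequence of ints, indexing always yields a batch. The sketch below assumes an int index produces a batch of size one and uses plain lists for the fields; IndexableTransitionsSketch is a hypothetical class, not part of gfn.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class IndexableTransitionsSketch:
    """Hypothetical container; each list holds one entry per transition."""

    states: List[int]
    actions: List[int]
    is_done: List[bool]

    def __len__(self) -> int:
        return len(self.states)

    def __getitem__(self, index: Union[int, Sequence[int]]) -> "IndexableTransitionsSketch":
        # Normalize a single index to a one-element list so the result is always a batch.
        indices = [index] if isinstance(index, int) else list(index)
        # Apply the same selection to every field to keep them aligned.
        return IndexableTransitionsSketch(
            states=[self.states[i] for i in indices],
            actions=[self.actions[i] for i in indices],
            is_done=[self.is_done[i] for i in indices],
        )


batch = IndexableTransitionsSketch(states=[0, 1, 2], actions=[1, 1, 0], is_done=[False, False, True])
sub = batch[[0, 2]]
print(sub, len(batch[1]))
```

Applying one index list to every field is what keeps states, actions and flags aligned after selection.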

__len__()

Return the number of transitions in the container.

Return type

int

__repr__()

Return repr(self).

extend(other)

Extend this Transitions object in place with the transitions of another Transitions object.

Parameters

other (Transitions) – the transitions to append.

Return type

None
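Since extend returns None, it mutates the container in place by concatenating each field. A minimal sketch, assuming plain lists for the fields and assuming, as a precondition of this sketch only, that forward and backward transitions should not be mixed; ExtendableTransitionsSketch is a hypothetical class.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ExtendableTransitionsSketch:
    """Hypothetical container; each list holds one entry per transition."""

    states: List[int]
    actions: List[int]
    is_backward: bool = False

    def extend(self, other: "ExtendableTransitionsSketch") -> None:
        # Assumed precondition: mixing forward and backward transitions is not meaningful.
        if self.is_backward != other.is_backward:
            raise ValueError("cannot mix forward and backward transitions")
        # Concatenate field by field, mutating self in place.
        self.states.extend(other.states)
        self.actions.extend(other.actions)


a = ExtendableTransitionsSketch(states=[0, 1], actions=[1, 1])
b = ExtendableTransitionsSketch(states=[2], actions=[0])
result = a.extend(b)
print(a.states, result)  # [0, 1, 2] None
```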