gfn.env

Module Contents

Classes

DiscreteEnv

Base class for discrete environments, where actions are represented by a number in {0, …, n_actions - 1}, the last one being the exit action.

Env

Base class for all environments. Environments require that individual states be represented as a unique tensor of arbitrary shape.

Attributes

NonValidActionsError

class gfn.env.DiscreteEnv(n_actions, s0, sf=None, device_str=None, preprocessor=None)

Bases: Env, abc.ABC

Base class for discrete environments, where actions are represented by a number in {0, …, n_actions - 1}, the last one being the exit action. DiscreteEnv allows specifying the validity of actions (forward and backward) via mask tensors that are directly attached to States objects.

Parameters
  • n_actions (int) –

  • s0 (torchtyping.TensorType[state_shape, torch.float]) –

  • sf (Optional[torchtyping.TensorType[state_shape, torch.float]]) –

  • device_str (Optional[str]) –

  • preprocessor (Optional[gfn.preprocessors.Preprocessor]) –
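The following is a minimal plain-Python sketch of this action convention, with lists in place of tensors. The environment is hypothetical. A state is a position x on a line of `length` cells. Action 0 moves one cell right, and action 1 (the last index, n_actions - 1) is the exit action. forward_mask and backward_mask are illustrative names and are not part of gfn. The toy also assumes that the backward action set excludes the exit action.

```python
# Hypothetical toy environment: positions 0..length-1 on a line.
# Forward actions: 0 = move right, n_actions - 1 = exit.

def forward_mask(x: int, length: int, n_actions: int = 2) -> list[bool]:
    """Return one boolean per forward action: True if it is allowed from x."""
    exit_action = n_actions - 1
    mask = [False] * n_actions
    mask[0] = x < length - 1   # moving right is impossible from the last cell
    mask[exit_action] = True   # in this toy, exiting is always allowed
    return mask


def backward_mask(x: int) -> list[bool]:
    """Assumed backward action set (no exit): a single 'undo right move'."""
    return [x > 0]             # the origin has no parent


# A batch of states carries one mask row per state.
batch = [0, 2, 3]
print([forward_mask(x, length=4) for x in batch])  # [[True, True], [True, True], [False, True]]
print([backward_mask(x) for x in batch])           # [[False], [True], [True]]
```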

property all_states: gfn.states.DiscreteStates

Returns a batch of all states. The batch_shape should be (n_states,). This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states)

Return type

gfn.states.DiscreteStates
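The indexing contract above can be illustrated with a hypothetical enumerable environment whose states are the cells of a 3 × 2 grid. This is a plain-Python sketch. get_states_indices here is a toy stand-in and not the gfn method.

```python
# Hypothetical enumerable environment: states are (x, y) cells of a 3x2 grid.
WIDTH, HEIGHT = 3, 2
n_states = WIDTH * HEIGHT

def get_states_indices(states):
    """Map each (x, y) state to a unique index in [0, n_states) (row-major)."""
    return [y * WIDTH + x for (x, y) in states]

# all_states lists every state in index order, so the contract holds.
all_states = [(x, y) for y in range(HEIGHT) for x in range(WIDTH)]
assert get_states_indices(all_states) == list(range(n_states))
```

With real tensors, `==` compares elementwise. The contract is therefore checked with something like `torch.equal(a, b)` or `(a == b).all()`, not with a single comparison.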

property n_states: int
Return type

int

property n_terminating_states: int
Return type

int

property terminating_states: gfn.states.DiscreteStates

Returns a batch of all terminating states for environments with enumerable states. The batch_shape should be (n_terminating_states,). This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)

Return type

gfn.states.DiscreteStates
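When only some states are terminating, the contract above holds only if the terminating indices are renumbered contiguously. Below is a plain-Python sketch with a made-up environment of six states in which the even-numbered states are the terminating ones.

```python
# Hypothetical environment: states 0..5, only even-numbered ones terminate.
all_states = list(range(6))
terminating_states = [s for s in all_states if s % 2 == 0]  # [0, 2, 4]
n_terminating_states = len(terminating_states)

def get_terminating_states_indices(states):
    """Renumber terminating states contiguously: 0 -> 0, 2 -> 1, 4 -> 2."""
    return [s // 2 for s in states]

assert get_terminating_states_indices(terminating_states) == list(range(n_terminating_states))
```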

property true_dist_pmf: torchtyping.TensorType[n_states, torch.float]

Returns a one-dimensional tensor with one entry per state, holding the probability mass function of the true distribution.

Return type

torchtyping.TensorType[n_states, torch.float]
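The sketch below shows how such a pmf is commonly built for a GFlowNet. It assumes, although the source does not say so, that the true distribution is proportional to the reward and that every state is terminating. The reward values are made up.

```python
# Made-up rewards of three terminating states (assumption: every state terminates).
rewards = [1.0, 3.0, 4.0]
Z = sum(rewards)                          # partition function
true_dist_pmf = [r / Z for r in rewards]  # p(x) = R(x) / Z
print(true_dist_pmf)                      # [0.125, 0.375, 0.5]
```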

get_states_indices(states)
Parameters

states (gfn.states.DiscreteStates) –

Return type

torchtyping.TensorType[batch_shape, torch.long]

get_terminating_states_indices(states)
Parameters

states (gfn.states.DiscreteStates) –

Return type

torchtyping.TensorType[batch_shape, torch.long]

is_action_valid(states, actions, backward=False)

Returns True if the actions are valid in the given states.

Parameters
  • states –

  • actions –

  • backward (bool) –

Return type

bool
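DiscreteEnv attaches forward and backward mask tensors to its states, so validity can be read directly from those masks. The plain-Python sketch below represents each mask as a list of booleans per state. These lists are hypothetical stand-ins for the mask tensors on DiscreteStates, and whether gfn's implementation works exactly this way is an assumption.

```python
def is_action_valid(forward_masks, backward_masks, actions, backward=False):
    """Return True only if every action is allowed by its state's mask row."""
    masks = backward_masks if backward else forward_masks
    return all(mask[a] for mask, a in zip(masks, actions))

# Two states of a toy line env; forward actions are 0 (move right) and 1 (exit).
forward_masks = [[True, True], [False, True]]  # second state can only exit
backward_masks = [[False], [True]]             # first state has no parent
print(is_action_valid(forward_masks, backward_masks, [0, 1]))                 # True
print(is_action_valid(forward_masks, backward_masks, [0, 0]))                 # False
print(is_action_valid(forward_masks, backward_masks, [0, 0], backward=True))  # False
```

With tensors, the same lookup is typically a `torch.gather` of the mask along the action dimension, followed by `.all()`.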

make_Actions_class()

Returns a class that inherits from Actions and implements the environment-specific methods.

Return type

type[gfn.actions.Actions]

step(states, actions)

Function that takes a batch of states and actions and returns a batch of next states and a boolean tensor indicating sink states in the new batch.

Parameters
  • states –

  • actions –

Return type

gfn.states.States
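In DiscreteEnv, choosing the exit action (index n_actions - 1) sends a state to the sink state sf. The plain-Python sketch below shows this for a hypothetical line environment, where the sentinel -1 stands in for sf. It returns the pair described above: the next states and flags that mark sink states.

```python
SINK = -1  # hypothetical sentinel standing in for the sink state sf

def step(states, actions, n_actions=2):
    """Toy line env: action 0 moves right, action n_actions - 1 exits to the sink."""
    exit_action = n_actions - 1
    next_states = [SINK if a == exit_action else x + 1
                   for x, a in zip(states, actions)]
    is_sink = [s == SINK for s in next_states]
    return next_states, is_sink

print(step([0, 2], [0, 1]))  # ([1, -1], [False, True])
```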

class gfn.env.Env(s0, sf=None, device_str=None, preprocessor=None)

Bases: abc.ABC

Base class for all environments. Environments require that individual states be represented as a unique tensor of arbitrary shape.

Parameters
  • s0 (torchtyping.TensorType[state_shape, torch.float]) –

  • sf (Optional[torchtyping.TensorType[state_shape, torch.float]]) –

  • device_str (Optional[str]) –

  • preprocessor (Optional[gfn.preprocessors.Preprocessor]) –

property log_partition: float

Returns the logarithm of the partition function.

Return type

float
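The sketch below computes the log-partition function stably with the log-sum-exp trick. It assumes, as is standard for GFlowNets although the source does not state it, that Z is the sum of the rewards of all terminating states and that their log-rewards are available as a list.

```python
import math

def log_partition(log_rewards):
    """Compute log Z = log(sum_x exp(log_reward(x))) with the log-sum-exp trick."""
    m = max(log_rewards)  # shift by the max so no exp() overflows
    return m + math.log(sum(math.exp(lr - m) for lr in log_rewards))

# exp(1000) alone would overflow a float; the shifted form stays finite.
print(log_partition([1000.0, 1000.0]))  # approximately 1000 + log(2)
```

For tensors, `torch.logsumexp` performs the same stabilized computation.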

backward_step(states, actions)

Function that takes a batch of states and actions and returns a batch of previous states and a boolean tensor indicating initial states in the new batch.

Parameters
  • states –

  • actions –

Return type

gfn.states.States
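The plain-Python sketch below shows a backward step for a hypothetical line environment. The initial state is s0 = 0, and a single backward action (index 0) moves one cell left. The function returns the previous states together with flags that mark which of them are initial.

```python
S0 = 0  # hypothetical initial state s0

def backward_step(states, actions):
    """Toy line env: backward action 0 moves one cell left."""
    prev_states = []
    for x, a in zip(states, actions):
        assert a == 0, "this toy env has a single backward action"
        prev_states.append(x - 1)
    is_initial = [x == S0 for x in prev_states]
    return prev_states, is_initial

print(backward_step([1, 3], [0, 0]))  # ([0, 2], [True, False])
```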

abstract is_action_valid(states, actions, backward=False)

Returns True if the actions are valid in the given states.

Parameters
  • states –

  • actions –

  • backward (bool) –

Return type

bool

abstract log_reward(final_states)

Returns the log-reward of each state in final_states. Either this method or reward must be implemented.

Parameters

final_states (gfn.states.States) –

Return type

torchtyping.TensorType[batch_shape, torch.float]
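Only one of the two reward methods has to be provided, so a subclass can implement log_reward and derive reward by exponentiation. The plain-Python sketch below uses a hypothetical ToyEnv with a made-up log-reward, log R(x) = x. Whether gfn's own default implementation does exactly this is an assumption.

```python
import math

class ToyEnv:
    """Hypothetical environment that implements only log_reward."""

    def log_reward(self, final_states):
        # Made-up log-reward: log R(x) = x.
        return [float(x) for x in final_states]

    def reward(self, final_states):
        # Derived from log_reward, so only one method holds real logic.
        return [math.exp(lr) for lr in self.log_reward(final_states)]

env = ToyEnv()
print(env.reward([0]))  # [1.0]
```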

abstract make_Actions_class()

Returns a class that inherits from Actions and implements the environment-specific methods.

Return type

type[gfn.actions.Actions]

abstract make_States_class()

Returns a class that inherits from States and implements the environment-specific methods.

Return type

type[gfn.states.States]

abstract maskless_backward_step(states, actions)

Function that takes a batch of states and actions and returns a batch of previous states. Does not need to check whether the actions are valid or the states are sink states.

Parameters
  • states –

  • actions –

Return type

torchtyping.TensorType[batch_shape, state_shape, torch.float]

abstract maskless_step(states, actions)

Function that takes a batch of states and actions and returns a batch of next states. Does not need to check whether the actions are valid or the states are sink states.

Parameters
  • states –

  • actions –

Return type

torchtyping.TensorType[batch_shape, state_shape, torch.float]
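Both maskless transitions above can be sketched in plain Python for a hypothetical 2-D hypergrid, where action i increments (forward) or decrements (backward) coordinate i. Neither function checks validity or sink states, which the maskless variants are allowed to skip. Tuples stand in for the returned state tensors.

```python
# Hypothetical 2-D hypergrid: action i changes coordinate i by one.

def maskless_step(states, actions):
    """Increment coordinate `a` of each state; no validity or sink checks."""
    return [tuple(c + (1 if i == a else 0) for i, c in enumerate(s))
            for s, a in zip(states, actions)]

def maskless_backward_step(states, actions):
    """Decrement coordinate `a` of each state; no validity or sink checks."""
    return [tuple(c - (1 if i == a else 0) for i, c in enumerate(s))
            for s, a in zip(states, actions)]

states = [(0, 0), (2, 1)]
actions = [1, 0]
nxt = maskless_step(states, actions)
print(nxt)                                             # [(0, 1), (3, 1)]
assert maskless_backward_step(nxt, actions) == states  # backward undoes forward
```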

reset(batch_shape=None, random=False, sink=False, seed=None)

Instantiates a batch of initial states. random and sink cannot both be True. When random is True and seed is not None, the given seed fixes the randomization, making the result reproducible.

Parameters
  • batch_shape (Optional[Union[int, Tuple[int]]]) –

  • random (bool) –

  • sink (bool) –

  • seed (int) –

Return type

gfn.states.States
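The plain-Python sketch below illustrates the reset contract with made-up toy values for s0, sf and the grid size. It enforces that random and sink are not both True. It also seeds a local random.Random, so a fixed seed reproduces the same random batch. Defaulting batch_shape to (1,) is an assumption of the sketch.

```python
import math
import random as rnd

S0, SF, N_CELLS = 0, -1, 5  # hypothetical toy values: initial state, sink, grid size

def reset(batch_shape=None, random=False, sink=False, seed=None):
    """Return a flat list of prod(batch_shape) toy states."""
    assert not (random and sink), "random and sink cannot both be True"
    if batch_shape is None:
        batch_shape = (1,)                 # assumed default for this sketch
    elif isinstance(batch_shape, int):
        batch_shape = (batch_shape,)
    n = math.prod(batch_shape)
    if sink:
        return [SF] * n
    if random:
        rng = rnd.Random(seed)             # seed=None draws fresh entropy
        return [rng.randrange(N_CELLS) for _ in range(n)]
    return [S0] * n

print(reset(3))  # [0, 0, 0]
```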

reward(final_states)

Returns the reward of each state in final_states. Either this method or log_reward must be implemented.

Parameters

final_states (gfn.states.States) –

Return type

torchtyping.TensorType[batch_shape, torch.float]

step(states, actions)

Function that takes a batch of states and actions and returns a batch of next states and a boolean tensor indicating sink states in the new batch.

Parameters
  • states –

  • actions –

Return type

gfn.states.States
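The maskless methods skip validity checks, which suggests that step performs those checks before delegating. The plain-Python sketch below shows that wrapper pattern with toy rules. The exception class is a local stand-in named after gfn.env.NonValidActionsError, and whether gfn's step actually raises it is an assumption.

```python
class NonValidActionsError(Exception):
    """Local stand-in named after gfn.env.NonValidActionsError."""

def is_action_valid(states, actions):
    # Toy rule: no state may move past position 3.
    return all(s + 1 <= 3 for s in states)

def maskless_step(states, actions):
    # Toy transition with no checks at all: move one cell right.
    return [s + 1 for s in states]

def step(states, actions):
    """Check validity first, then delegate to the unchecked transition."""
    if not is_action_valid(states, actions):
        raise NonValidActionsError("some actions are not valid in the given states")
    return maskless_step(states, actions)

print(step([0, 2], [0, 0]))  # [1, 3]
try:
    step([3], [0])
except NonValidActionsError as err:
    print("rejected:", err)
```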

validate_actions(states, actions, backward=False)

Asserts that states and actions have the same batch_shape, then calls is_action_valid. Returns a boolean indicating whether the state/action pairs are valid.

Parameters
  • states –

  • actions –

  • backward (bool) –

Return type

bool
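The plain-Python sketch below performs the two checks described above. It uses hypothetical ToyStates and ToyActions containers that expose a batch_shape attribute. Under its toy rule, backward actions are invalid from position 0.

```python
from dataclasses import dataclass

@dataclass
class ToyStates:
    batch_shape: tuple[int, ...]  # hypothetical stand-in for States.batch_shape
    positions: list[int]

@dataclass
class ToyActions:
    batch_shape: tuple[int, ...]
    indices: list[int]

def is_action_valid(states, actions, backward=False):
    # Toy rule (actions unused): position 0 has no parent; forward is always valid.
    if backward:
        return all(p > 0 for p in states.positions)
    return True

def validate_actions(states, actions, backward=False):
    """Assert matching batch shapes, then defer to is_action_valid."""
    assert states.batch_shape == actions.batch_shape, "batch shapes differ"
    return is_action_valid(states, actions, backward)

s = ToyStates((2,), [0, 1])
a = ToyActions((2,), [0, 0])
print(validate_actions(s, a))                 # True
print(validate_actions(s, a, backward=True))  # False: position 0 has no parent
```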

gfn.env.NonValidActionsError