gfn.envs

Package Contents

Classes

DiscreteEBMEnv

Environment for discrete energy-based models, based on https://arxiv.org/pdf/2202.01361.pdf

Env

Base class for environments, showing which methods should be implemented.

HyperGrid

Environment whose states are the cells of an ndim-dimensional grid with height cells along each dimension.

class gfn.envs.DiscreteEBMEnv(ndim, energy=None, alpha=1.0, device_str='cpu')

Bases: gfn.envs.env.Env

Environment for discrete energy-based models, based on https://arxiv.org/pdf/2202.01361.pdf

Parameters
  • ndim (int) –

  • energy (EnergyFunction | None) –

  • alpha (float) –

  • device_str (Literal[cpu, cuda]) –

property all_states: gfn.containers.states.States

Returns a batch of all states for environments with enumerable states. The batch_shape should be (n_states,). This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states)

Return type

gfn.containers.states.States

property log_partition: float

Returns the logarithm of the partition function.

Return type

float

property n_states: int
Return type

int

property n_terminating_states: int
Return type

int

property terminating_states: gfn.containers.states.States

Returns a batch of all terminating states for environments with enumerable states. The batch_shape should be (n_terminating_states,). This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)

Return type

gfn.containers.states.States

property true_dist_pmf: torch.Tensor

Returns a one-dimensional tensor representing the true distribution.

Return type

torch.Tensor

get_states_indices(states)

Each coordinate of a state is mapped as -1 -> 0, 0 -> 1, 1 -> 2, and the resulting vector of digits is then read as a base-3 number.

Parameters

states (gfn.containers.states.States) –

Return type

BatchTensor
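As an illustration of this encoding, here is a pure-Python sketch on tuples instead of States objects. It assumes that the first coordinate is the most significant base-3 digit, and that terminating states (entries in {0, 1} only) are indexed the same way in base 2; both digit orders are assumptions, and the helper names are hypothetical.

```python
from itertools import product
from typing import Sequence


def ebm_state_index(state: Sequence[int]) -> int:
    """Hypothetical helper: map a state in {-1, 0, 1}^ndim to [0, 3**ndim)."""
    index = 0
    for x in state:
        # -1 -> 0, 0 -> 1, 1 -> 2; the first coordinate is the most significant digit
        index = 3 * index + (x + 1)
    return index


def ebm_terminating_state_index(state: Sequence[int]) -> int:
    """Hypothetical helper: map a terminating state in {0, 1}^ndim to [0, 2**ndim)."""
    index = 0
    for x in state:
        index = 2 * index + x
    return index


ndim = 2
all_states = list(product((-1, 0, 1), repeat=ndim))
print([ebm_state_index(s) for s in all_states])  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
terminating = list(product((0, 1), repeat=ndim))
print([ebm_terminating_state_index(s) for s in terminating])  # [0, 1, 2, 3]
```

With this ordering, enumerating states with itertools.product yields the indices 0, 1, 2, … in order, which is the invariant required of all_states and terminating_states.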

get_terminating_states_indices(states)
Parameters

states (gfn.containers.states.States) –

Return type

BatchTensor

is_exit_actions(actions)

Returns a boolean tensor that is True wherever the action is the exit action.

Parameters

actions (BatchTensor) –

Return type

BatchTensor

log_reward(final_states)

Returns the log-reward of each final state. Either this method or reward must be implemented.

Parameters

final_states (gfn.containers.states.States) –

Return type

BatchTensor
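A minimal sketch of an energy-based log-reward, assuming (this is not stated above) that the log-reward is the negated energy scaled by alpha, log R(x) = -alpha · E(x). The toy energy below is hypothetical and only stands in for an EnergyFunction.

```python
from typing import Callable, Sequence


def ebm_log_reward(
    state: Sequence[int],
    energy: Callable[[Sequence[int]], float],
    alpha: float = 1.0,
) -> float:
    # Assumed convention: lower energy means higher reward, scaled by alpha.
    return -alpha * energy(state)


def toy_energy(state: Sequence[int]) -> float:
    # Hypothetical energy: the number of coordinates set to 1.
    return float(sum(state))


print(ebm_log_reward((1, 0, 1), toy_energy, alpha=0.5))  # -1.0
```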

make_States_class()

Returns a class that inherits from States and implements the environment-specific methods.

Return type

type[gfn.containers.states.States]

maskless_backward_step(states, actions)

Same as backward_step, but does not check whether the actions are valid and does not apply any masking; the states tensor is modified in place.

Parameters
  • states (StatesTensor) –

  • actions (BatchTensor) –

Return type

None

maskless_step(states, actions)

Same as step, but does not check whether the actions are valid and does not apply any masking; the states tensor is modified in place.

Parameters
  • states (StatesTensor) –

  • actions (BatchTensor) –

Return type

None
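The two maskless methods can be pictured with the following pure-Python sketch operating on lists of coordinates. The action layout (actions 0 to ndim - 1 set a coordinate to 0, actions ndim to 2 * ndim - 1 set a coordinate to 1, action 2 * ndim being the exit action) is an assumption, as are the function names. Like the methods above, the sketch returns None, edits the states in place and performs no validity checks.

```python
from typing import List


def ebm_maskless_step(states: List[List[int]], actions: List[int], ndim: int) -> None:
    """Hypothetical forward step; exit actions are assumed filtered out beforehand."""
    for state, a in zip(states, actions):
        if a < ndim:
            state[a] = 0  # assumed: action a sets coordinate a from -1 to 0
        else:
            state[a - ndim] = 1  # assumed: action a sets coordinate a - ndim from -1 to 1


def ebm_maskless_backward_step(states: List[List[int]], actions: List[int], ndim: int) -> None:
    """Hypothetical backward step; resets the chosen coordinate to -1 in place."""
    for state, a in zip(states, actions):
        state[a % ndim] = -1


states = [[-1, -1, -1], [0, -1, 1]]
ebm_maskless_step(states, [0, 4], ndim=3)
print(states)  # [[0, -1, -1], [0, 1, 1]]
ebm_maskless_backward_step(states, [0, 4], ndim=3)
print(states)  # [[-1, -1, -1], [0, -1, 1]]
```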

class gfn.envs.Env(action_space, s0, sf=None, device_str=None, preprocessor=None)

Bases: abc.ABC

Base class for environments, showing which methods should be implemented. A common assumption for all environments is that all actions are discrete, represented by a number in {0, …, n_actions - 1}, the last one being the exit action.

Parameters
  • action_space (gymnasium.spaces.Space) –

  • s0 (OneStateTensor) –

  • sf (Optional[OneStateTensor]) –

  • device_str (Optional[str]) –

  • preprocessor (Optional[gfn.envs.preprocessors.Preprocessor]) –
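Under this convention, detecting exit actions reduces to comparing each action with n_actions - 1. The sketch below uses plain lists; on a torch tensor the same test would be written as actions == n_actions - 1. The function name is hypothetical.

```python
from typing import List


def is_exit_actions(actions: List[int], n_actions: int) -> List[bool]:
    # By convention the exit action is the last index, n_actions - 1.
    return [a == n_actions - 1 for a in actions]


print(is_exit_actions([0, 3, 4, 4], n_actions=5))  # [False, False, True, True]
```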

property all_states: gfn.containers.states.States

Returns a batch of all states for environments with enumerable states. The batch_shape should be (n_states,). This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states)

Return type

gfn.containers.states.States
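The invariant on all_states can be checked mechanically. The helper below is a hypothetical sketch on plain Python sequences: given an enumeration and an indexing function, it tests that the indices are exactly 0, 1, …, n_states - 1, in order. The toy environment of 3-bit strings is only an example.

```python
from itertools import product
from typing import Callable, Sequence, Tuple


def satisfies_enumeration_invariant(
    states: Sequence[Tuple[int, ...]],
    get_index: Callable[[Tuple[int, ...]], int],
    n_states: int,
) -> bool:
    # Mirrors: get_states_indices(all_states) == torch.arange(n_states)
    return len(states) == n_states and [get_index(s) for s in states] == list(range(n_states))


def bits_to_index(bits: Tuple[int, ...]) -> int:
    # Read a tuple of 0/1 values as a binary number, first bit most significant.
    return int("".join(str(b) for b in bits), 2)


bit_strings = list(product((0, 1), repeat=3))
print(satisfies_enumeration_invariant(bit_strings, bits_to_index, 8))  # True
print(satisfies_enumeration_invariant(bit_strings[::-1], bits_to_index, 8))  # False
```

The same check applies to terminating_states, with get_terminating_states_indices and n_terminating_states in place of their all-states counterparts.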

property log_partition: float

Returns the logarithm of the partition function.

Return type

float

property n_actions: int
Return type

int

property n_states: int
Return type

int

property n_terminating_states: int
Return type

int

property terminating_states: gfn.containers.states.States

Returns a batch of all terminating states for environments with enumerable states. The batch_shape should be (n_terminating_states,). This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)

Return type

gfn.containers.states.States

property true_dist_pmf: PmfTensor

Returns a one-dimensional tensor representing the true distribution.

Return type

PmfTensor
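For enumerable environments, log_partition and true_dist_pmf are tied to the rewards of the terminating states: Z = Σ_x R(x), log_partition = log Z, and the true distribution assigns R(x) / Z to each terminating state x. A pure-Python sketch working from log-rewards (with the log-sum-exp trick for numerical stability) might look like this; the function names are hypothetical, and the log-rewards are assumed to be listed in terminating-state index order.

```python
import math
from typing import List


def log_partition(log_rewards: List[float]) -> float:
    # log Z = log sum_x exp(log R(x)), computed stably with log-sum-exp
    m = max(log_rewards)
    return m + math.log(sum(math.exp(lr - m) for lr in log_rewards))


def true_dist_pmf(log_rewards: List[float]) -> List[float]:
    # p(x) = R(x) / Z = exp(log R(x) - log Z)
    log_z = log_partition(log_rewards)
    return [math.exp(lr - log_z) for lr in log_rewards]


log_rewards = [math.log(1.0), math.log(2.0), math.log(1.0)]
print(round(math.exp(log_partition(log_rewards)), 6))  # 4.0
print([round(p, 6) for p in true_dist_pmf(log_rewards)])  # [0.25, 0.5, 0.25]
```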

backward_step(states, actions)

Takes a batch of states and a batch of backward actions, checks that the actions are valid, and returns the batch of parent states.

Parameters
  • states (gfn.containers.states.States) –

  • actions (TensorLong) –

Return type

gfn.containers.states.States

get_states_indices(states)
Parameters

states (gfn.containers.states.States) –

Return type

TensorLong

get_terminating_states_indices(states)
Parameters

states (gfn.containers.states.States) –

Return type

TensorLong

abstract is_exit_actions(actions)

Returns a boolean tensor that is True wherever the action is the exit action.

Parameters

actions (TensorLong) –

Return type

TensorBool

abstract log_reward(final_states)

Returns the log-reward of each final state. Either this method or reward must be implemented.

Parameters

final_states (gfn.containers.states.States) –

Return type

TensorFloat

abstract make_States_class()

Returns a class that inherits from States and implements the environment-specific methods.

Return type

type[gfn.containers.states.States]

abstract maskless_backward_step(states, actions)

Same as backward_step, but does not check whether the actions are valid and does not apply any masking; the states tensor is modified in place.

Parameters
  • states (StatesTensor) –

  • actions (TensorLong) –

Return type

None

abstract maskless_step(states, actions)

Same as step, but does not check whether the actions are valid and does not apply any masking; the states tensor is modified in place.

Parameters
  • states (StatesTensor) –

  • actions (TensorLong) –

Return type

None

reset(batch_shape, random=False)

Instantiates a batch of initial states.

Parameters
  • batch_shape (Union[int, Tuple[int]]) –

  • random (bool) –

Return type

gfn.containers.states.States

reward(final_states)

Returns the reward of each final state. Either this method or log_reward must be implemented.

Parameters

final_states (gfn.containers.states.States) –

Return type

TensorFloat
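One way to let a subclass implement either method is for each default to fall back on the other, as in the hypothetical sketch below (class names are illustrative, and values are plain floats instead of tensors). In this listing Env.log_reward is marked abstract, so the actual base class may require log_reward to be defined regardless.

```python
import math


class RewardDefaultsSketch:
    """Hypothetical base: override reward or log_reward (at least one of them)."""

    def reward(self, final_state) -> float:
        # Default: exponentiate the log-reward.
        return math.exp(self.log_reward(final_state))

    def log_reward(self, final_state) -> float:
        # Default: take the log of the reward.
        return math.log(self.reward(final_state))


class ToyEnv(RewardDefaultsSketch):
    def reward(self, final_state) -> float:
        # Hypothetical reward: 1 plus the sum of the coordinates.
        return 1.0 + sum(final_state)


env = ToyEnv()
print(env.reward((1, 2)))  # 4.0
print(round(env.log_reward((1, 2)), 6))  # 1.386294
```

If a subclass overrides neither method, the two defaults call each other until Python raises RecursionError, which is why at least one of them must be implemented.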

step(states, actions)

Takes a batch of states and a batch of actions, checks that the actions are valid, and returns the batch of next states.

Parameters
  • states (gfn.containers.states.States) –

  • actions (TensorLong) –

Return type

gfn.containers.states.States

class gfn.envs.HyperGrid(ndim=2, height=4, R0=0.1, R1=0.5, R2=2.0, reward_cos=False, device_str='cpu', preprocessor_name='KHot')

Bases: gfn.envs.env.Env

Environment whose states are the cells of an ndim-dimensional grid with height cells along each dimension. As in every Env, actions are discrete, represented by a number in {0, …, n_actions - 1}, the last one being the exit action.

Parameters
  • ndim (int) –

  • height (int) –

  • R0 (float) –

  • R1 (float) –

  • R2 (float) –

  • reward_cos (bool) –

  • device_str (Literal[cpu, cuda]) –

  • preprocessor_name (Literal[KHot, OneHot, Identity]) –
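The parameters R0, R1 and R2 control the reward. Below is a sketch of the hypergrid reward as commonly used in GFlowNet experiments: every state receives R0, states whose coordinates all lie far from the centre of the grid receive an extra R1, and states whose coordinates all lie in a narrower band near the corners receive an extra R2. That HyperGrid uses exactly these thresholds (0.25, and the band 0.3 to 0.4) is an assumption, the reward_cos variant is not covered, and the function name is hypothetical.

```python
from typing import Sequence


def hypergrid_reward(
    x: Sequence[int], height: int, R0: float = 0.1, R1: float = 0.5, R2: float = 2.0
) -> float:
    """Hypothetical sketch of the non-cosine hypergrid reward (reward_cos=False)."""
    # Distance of each normalized coordinate from the centre of the grid.
    ax = [abs(xi / (height - 1) - 0.5) for xi in x]
    far_from_centre = all(a > 0.25 for a in ax)
    in_band = all(0.3 < a < 0.4 for a in ax)
    return R0 + R1 * far_from_centre + R2 * in_band


print(hypergrid_reward((0, 0), height=8))  # 0.6
print(hypergrid_reward((1, 1), height=8))  # 2.6
print(hypergrid_reward((3, 3), height=8))  # 0.1
```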

property all_states: gfn.containers.states.States

Returns a batch of all states for environments with enumerable states. The batch_shape should be (n_states,). This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states)

Return type

gfn.containers.states.States

property log_partition: float

Returns the logarithm of the partition function.

Return type

float

property n_states: int
Return type

int

property n_terminating_states: int
Return type

int

property terminating_states: gfn.containers.states.States

Returns a batch of all terminating states for environments with enumerable states. The batch_shape should be (n_terminating_states,). This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)

Return type

gfn.containers.states.States

property true_dist_pmf: torch.Tensor

Returns a one-dimensional tensor representing the true distribution.

Return type

torch.Tensor

build_grid()

Builds a batch of states containing every cell of the grid.

Return type

gfn.containers.states.States

get_states_indices(states)
Parameters

states (gfn.containers.states.States) –

Return type

TensorLong

get_terminating_states_indices(states)
Parameters

states (gfn.containers.states.States) –

Return type

TensorLong
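A pure-Python sketch of build_grid and of the state indexing, assuming that states are tuples of integer coordinates in {0, …, height - 1}, that the grid is enumerated in row-major order (last coordinate varying fastest), and that the index reads the coordinates as a base-height number. Under these assumptions n_states = height ** ndim, and if every cell can be a terminating state, the same index can also serve get_terminating_states_indices. The function names are hypothetical.

```python
from itertools import product
from typing import List, Sequence, Tuple


def build_grid(ndim: int, height: int) -> List[Tuple[int, ...]]:
    # Every cell of the grid, last coordinate varying fastest.
    return list(product(range(height), repeat=ndim))


def hypergrid_state_index(x: Sequence[int], height: int) -> int:
    # Read the coordinates as the digits of a base-`height` number.
    index = 0
    for xi in x:
        index = index * height + xi
    return index


grid = build_grid(ndim=2, height=4)
print(len(grid))  # 16
print(hypergrid_state_index((1, 2), height=4))  # 6
print([hypergrid_state_index(x, 4) for x in grid] == list(range(16)))  # True
```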

is_exit_actions(actions)

Returns a boolean tensor that is True wherever the action is the exit action.

Parameters

actions (TensorLong) –

Return type

TensorBool

log_reward(final_states)

Returns the log-reward of each final state. Either this method or reward must be implemented.

Parameters

final_states (gfn.containers.states.States) –

Return type

TensorFloat

make_States_class()

Creates the States class for this environment.

Return type

type[gfn.containers.states.States]

maskless_backward_step(states, actions)

Same as backward_step, but does not check whether the actions are valid and does not apply any masking; the states tensor is modified in place.

Parameters
  • states (StatesTensor) –

  • actions (TensorLong) –

Return type

None

maskless_step(states, actions)

Same as step, but does not check whether the actions are valid and does not apply any masking; the states tensor is modified in place.

Parameters
  • states (StatesTensor) –

  • actions (TensorLong) –

Return type

None
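In the hypergrid, a forward action moves one coordinate up by one cell and a backward action moves it down by one. The sketch below assumes that action i (for i < ndim) targets coordinate i and that the exit action (index ndim) has been filtered out beforehand; the function names are hypothetical and, matching the None return type above, the states are modified in place without validity checks.

```python
from typing import List


def hypergrid_maskless_step(states: List[List[int]], actions: List[int]) -> None:
    # Assumed: action i increments coordinate i by one.
    for state, a in zip(states, actions):
        state[a] += 1


def hypergrid_maskless_backward_step(states: List[List[int]], actions: List[int]) -> None:
    # Assumed: backward action i decrements coordinate i by one.
    for state, a in zip(states, actions):
        state[a] -= 1


states = [[0, 0], [2, 1]]
hypergrid_maskless_step(states, [1, 0])
print(states)  # [[0, 1], [3, 1]]
hypergrid_maskless_backward_step(states, [1, 0])
print(states)  # [[0, 0], [2, 1]]
```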

true_reward(final_states)
Parameters

final_states (gfn.containers.states.States) –

Return type

TensorFloat