gfn.gym

Package Contents

Classes

Box

Box environment, corresponding to the one in Section 4.1 of https://arxiv.org/abs/2301.12594

DiscreteEBM

Environment for discrete energy-based models, based on https://arxiv.org/pdf/2202.01361.pdf

HyperGrid

Base class for discrete environments, where actions are represented by a number in {0, …, n_actions - 1}, the last one being the exit action.

class gfn.gym.Box(delta=0.1, R0=0.1, R1=0.5, R2=2.0, epsilon=0.0001, device_str='cpu')

Bases: gfn.env.Env

Box environment, corresponding to the one in Section 4.1 of https://arxiv.org/abs/2301.12594

Parameters
  • delta (float) –

  • R0 (float) –

  • R1 (float) –

  • R2 (float) –

  • epsilon (float) –

  • device_str (Literal[cpu, cuda]) –

property log_partition: float

Returns the logarithm of the partition function.

Return type

float

is_action_valid(states, actions, backward=False)

Returns True if the actions are valid in the given states.

Parameters
  • states –

  • actions –

  • backward –

Return type

bool

log_reward(final_states)

Returns the logarithm of the reward of each final state. Either this method or reward needs to be implemented.

Parameters

final_states (gfn.states.States) –

Return type

torchtyping.TensorType[batch_shape, torch.float]
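This page does not spell out the Box reward. As an illustration only, assuming it has the same two-band shape as the HyperGrid reward documented further down, applied directly to a point x in [0, 1]^2, the log-reward of a single point could be sketched in plain Python as follows. The function name box_log_reward and the exact band boundaries are assumptions; R0, R1 and R2 stand in for the constructor arguments of the same names.

```python
import math
from typing import Sequence


def box_log_reward(
    x: Sequence[float], R0: float = 0.1, R1: float = 0.5, R2: float = 2.0
) -> float:
    """Hypothetical Box log-reward for one point x in [0, 1]^2 (assumed formula)."""
    # Distance of each coordinate from the centre of the box.
    ax = [abs(xi - 0.5) for xi in x]
    # R1 bonus (assumed): every coordinate lies in the outer band |x_d - 0.5| > 0.25.
    outer = all(a > 0.25 for a in ax)
    # R2 bonus (assumed): every coordinate lies in the thin band 0.3 < |x_d - 0.5| < 0.4.
    ring = all(0.3 < a < 0.4 for a in ax)
    reward = R0 + R1 * outer + R2 * ring
    return math.log(reward)


print(box_log_reward([0.5, 0.5]))    # centre of the box: only R0 contributes
print(box_log_reward([0.15, 0.85]))  # inside both bands: R0 + R1 + R2
```

Returning the log of a strictly positive reward (R0 > 0) keeps the value finite everywhere in the box.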

make_Actions_class()

Returns a class that inherits from Actions and implements the environment-specific methods.

Return type

type[gfn.actions.Actions]

make_States_class()

Returns a class that inherits from States and implements the environment-specific methods.

Return type

type[gfn.states.States]

maskless_backward_step(states, actions)

Function that takes a batch of states and actions and returns a batch of previous states. Does not need to check whether the actions are valid or the states are sink states.

Parameters
  • states –

  • actions –

Return type

torchtyping.TensorType[batch_shape, 2, torch.float]

maskless_step(states, actions)

Function that takes a batch of states and actions and returns a batch of next states. Does not need to check whether the actions are valid or the states are sink states.

Parameters
  • states –

  • actions –

Return type

torchtyping.TensorType[batch_shape, 2, torch.float]
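The two maskless step functions above return tensors of shape [batch_shape, 2], i.e. one 2-D point per state. A minimal sketch, assuming that a Box action is a 2-D displacement vector added to the state on a forward step and subtracted on a backward step, is shown below. The names box_step and box_backward_step are hypothetical, plain lists stand in for tensors, and, as the docstrings say, no validity check is performed.

```python
from typing import List

Point = List[float]


def box_step(states: List[Point], actions: List[Point]) -> List[Point]:
    """Hypothetical forward step: move each 2-D state by its action vector."""
    return [[s[0] + a[0], s[1] + a[1]] for s, a in zip(states, actions)]


def box_backward_step(states: List[Point], actions: List[Point]) -> List[Point]:
    """Hypothetical backward step: undo the displacement of the forward step."""
    return [[s[0] - a[0], s[1] - a[1]] for s, a in zip(states, actions)]


states = [[0.0, 0.0], [0.3, 0.4]]
actions = [[0.05, 0.05], [0.06, -0.08]]
nxt = box_step(states, actions)
# Stepping back with the same actions recovers the original states (up to rounding).
prev = box_backward_step(nxt, actions)
```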

static norm(x)
Parameters

x (torchtyping.TensorType[batch_shape, 2, torch.float]) –

Return type

torch.Tensor

class gfn.gym.DiscreteEBM(ndim, energy=None, alpha=1.0, device_str='cpu', preprocessor_name='Identity')

Bases: gfn.env.DiscreteEnv

Environment for discrete energy-based models, based on https://arxiv.org/pdf/2202.01361.pdf

Parameters
  • ndim (int) –

  • energy (EnergyFunction | None) –

  • alpha (float) –

  • device_str (Literal[cpu, cuda]) –

  • preprocessor_name (Literal[Identity, Enum]) –

property all_states: gfn.states.DiscreteStates

Returns a batch of all states. The batch_shape should be (n_states,). This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states)

Return type

gfn.states.DiscreteStates

property log_partition: float

Returns the logarithm of the partition function.

Return type

float

property n_states: int
Return type

int

property n_terminating_states: int
Return type

int

property terminating_states: gfn.states.DiscreteStates

Returns a batch of all terminating states for environments with enumerable states. The batch_shape should be (n_terminating_states,). This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)

Return type

gfn.states.DiscreteStates

property true_dist_pmf: torch.Tensor

Returns a one-dimensional tensor representing the true distribution.

Return type

torch.Tensor
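For an energy-based model the target distribution is p(x) ∝ exp(-alpha · E(x)), so the log partition function is the log-sum-exp of the log-rewards and the true PMF is each reward divided by Z. The sketch below assumes that the terminating states are the binary vectors in {0, 1}^ndim and that the log-reward is -alpha · E(x); the Ising-style toy_energy and the helper ebm_log_partition_and_pmf are hypothetical and are not the default EnergyFunction.

```python
import itertools
import math
from typing import Callable, List, Sequence, Tuple


def toy_energy(x: Sequence[int]) -> float:
    """Hypothetical Ising-style energy: neighbouring equal bits lower the energy."""
    spins = [2 * xi - 1 for xi in x]  # map {0, 1} -> {-1, +1}
    return -float(sum(a * b for a, b in zip(spins, spins[1:])))


def ebm_log_partition_and_pmf(
    ndim: int, energy: Callable[[Sequence[int]], float], alpha: float = 1.0
) -> Tuple[float, List[float]]:
    # Assumed: terminating states are all binary vectors in {0, 1}^ndim.
    xs = list(itertools.product((0, 1), repeat=ndim))
    log_rewards = [-alpha * energy(x) for x in xs]
    # log Z = logsumexp(log R(x)), computed stably by factoring out the maximum.
    m = max(log_rewards)
    log_z = m + math.log(sum(math.exp(lr - m) for lr in log_rewards))
    # p(x) = R(x) / Z, evaluated in log space.
    pmf = [math.exp(lr - log_z) for lr in log_rewards]
    return log_z, pmf


log_z, pmf = ebm_log_partition_and_pmf(ndim=3, energy=toy_energy, alpha=1.0)
```

Factoring out the maximum log-reward avoids overflow in exp when alpha · E(x) is large in magnitude.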

get_states_indices(states)

The chosen encoding is the following: each coordinate is mapped -1 -> 0, 0 -> 1, 1 -> 2, and the resulting digits are then read as a base-3 number.

Parameters

states (gfn.states.DiscreteStates) –

Return type

torchtyping.TensorType[batch_shape]
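The base-3 encoding described above can be sketched in plain Python for a single state. The digit order (first coordinate as the most significant digit) is an assumption, and ebm_state_index is a hypothetical helper. The final check mirrors the documented property self.get_states_indices(self.all_states) == torch.arange(self.n_states) for states enumerated in lexicographic order.

```python
import itertools
from typing import List, Sequence, Tuple


def ebm_state_index(state: Sequence[int]) -> int:
    """Encode a state in {-1, 0, 1}^ndim as an integer in [0, 3**ndim)."""
    index = 0
    for value in state:
        # -1 -> 0, 0 -> 1, 1 -> 2, then append as the next base-3 digit
        # (assumed: the first coordinate is the most significant digit).
        index = 3 * index + (value + 1)
    return index


ndim = 3
# Enumerate all states in the same (lexicographic) order as the encoding.
all_states: List[Tuple[int, ...]] = list(itertools.product((-1, 0, 1), repeat=ndim))
indices = [ebm_state_index(s) for s in all_states]
assert indices == list(range(3 ** ndim))
```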

get_terminating_states_indices(states)
Parameters

states (gfn.states.DiscreteStates) –

Return type

torchtyping.TensorType[batch_shape]
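This page does not document the encoding used for terminating states. Assuming, for illustration only, that terminating states are the fully assigned vectors in {0, 1}^ndim and are indexed by reading them as base-2 numbers, the property self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states) can be checked with the hypothetical helper below.

```python
import itertools
from typing import Sequence


def ebm_terminating_index(state: Sequence[int]) -> int:
    """Hypothetical encoding of a terminating state in {0, 1}^ndim as a base-2 integer."""
    index = 0
    for bit in state:
        index = 2 * index + bit  # first coordinate is the most significant bit
    return index


ndim = 4
terminating_states = list(itertools.product((0, 1), repeat=ndim))
assert [ebm_terminating_index(s) for s in terminating_states] == list(range(2 ** ndim))
```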

is_exit_actions(actions)
Parameters

actions (torchtyping.TensorType[batch_shape]) –

Return type

torchtyping.TensorType[batch_shape]
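Discrete environments represent actions as integers in {0, …, n_actions - 1}, with the last one being the exit action, so detecting exit actions reduces to an equality test against n_actions - 1. The sketch below uses plain lists instead of tensors; the value n_actions = 7 in the example is arbitrary and is not taken from DiscreteEBM.

```python
from typing import List


def is_exit_actions(actions: List[int], n_actions: int) -> List[bool]:
    """Flag which integer actions are the exit action (index n_actions - 1)."""
    exit_action = n_actions - 1
    return [a == exit_action for a in actions]


# With an arbitrary n_actions = 7, only action 6 is the exit action.
flags = is_exit_actions([0, 6, 3, 6], n_actions=7)
```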

log_reward(final_states)

Returns the logarithm of the reward of each final state. Either this method or reward needs to be implemented.

Parameters

final_states (gfn.states.DiscreteStates) –

Return type

torchtyping.TensorType[batch_shape]

make_States_class()

Returns a class that inherits from States and implements the environment-specific methods.

Return type

type[gfn.states.DiscreteStates]

maskless_backward_step(states, actions)

Function that takes a batch of states and actions and returns a batch of previous states. Does not need to check whether the actions are valid or the states are sink states.

Parameters
  • states –

  • actions –

Return type

torchtyping.TensorType[batch_shape, state_shape, torch.float]

maskless_step(states, actions)

Function that takes a batch of states and actions and returns a batch of next states. Does not need to check whether the actions are valid or the states are sink states.

Parameters
  • states –

  • actions –

Return type

torchtyping.TensorType[batch_shape, state_shape, torch.float]
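The action layout of DiscreteEBM is not given on this page. Assuming, for illustration, that actions 0 to ndim - 1 write a 0 into an unset (-1) coordinate, actions ndim to 2·ndim - 1 write a 1, and a backward action resets the same coordinate to -1, the two maskless steps for a single state could look like this. ebm_step and ebm_backward_step are hypothetical names; as the docstrings say, neither function checks validity.

```python
from typing import List


def ebm_step(state: List[int], action: int, ndim: int) -> List[int]:
    """Hypothetical forward step: fill one unset (-1) coordinate with 0 or 1."""
    nxt = list(state)  # copy so the input state is not mutated
    if action < ndim:
        nxt[action] = 0          # assumed: actions 0 .. ndim-1 write a 0
    elif action < 2 * ndim:
        nxt[action - ndim] = 1   # assumed: actions ndim .. 2*ndim-1 write a 1
    return nxt


def ebm_backward_step(state: List[int], action: int, ndim: int) -> List[int]:
    """Hypothetical backward step: reset the coordinate touched by `action` to -1."""
    prev = list(state)
    prev[action % ndim] = -1
    return prev


ndim = 3
s0 = [-1] * ndim
s1 = ebm_step(s0, 1, ndim)          # coordinate 1 set to 0
s2 = ebm_step(s1, ndim + 2, ndim)   # coordinate 2 set to 1
```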

class gfn.gym.HyperGrid(ndim=2, height=4, R0=0.1, R1=0.5, R2=2.0, reward_cos=False, device_str='cpu', preprocessor_name='KHot')

Bases: gfn.env.DiscreteEnv

Base class for discrete environments, where actions are represented by a number in {0, …, n_actions - 1}, the last one being the exit action. DiscreteEnv allows specifying the validity of actions (forward and backward) via mask tensors that are directly attached to States objects.
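To make the mask mechanism concrete, the sketch below builds forward and backward masks for a single grid state. It assumes a HyperGrid-style layout that this page does not document: n_actions = ndim + 1, action d increments coordinate d, and the last action exits. The helper names are hypothetical and plain lists stand in for the boolean mask tensors attached to States.

```python
from typing import List, Sequence


def hypergrid_forward_mask(state: Sequence[int], height: int) -> List[bool]:
    """Hypothetical forward mask for one grid state.

    Assumes n_actions = ndim + 1: action d increments coordinate d, and the
    last action (index ndim) is the exit action.
    """
    # Incrementing coordinate d is only allowed while it stays inside the grid.
    mask = [coord < height - 1 for coord in state]
    mask.append(True)  # the exit action is always allowed in this sketch
    return mask


def hypergrid_backward_mask(state: Sequence[int]) -> List[bool]:
    """Hypothetical backward mask: decrementing coordinate d requires coord > 0."""
    return [coord > 0 for coord in state]


print(hypergrid_forward_mask([0, 3], height=4))  # second coordinate is at the edge
print(hypergrid_backward_mask([0, 3]))
```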

Parameters
  • ndim (int) –

  • height (int) –

  • R0 (float) –

  • R1 (float) –

  • R2 (float) –

  • reward_cos (bool) –

  • device_str (Literal[cpu, cuda]) –

  • preprocessor_name (Literal[KHot, OneHot, Identity, Enum]) –

property all_states: gfn.states.DiscreteStates

Returns a batch of all states. The batch_shape should be (n_states,). This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states)

Return type

gfn.states.DiscreteStates

property log_partition: float

Returns the logarithm of the partition function.

Return type

float

property n_states: int
Return type

int

property n_terminating_states: int
Return type

int

property terminating_states: gfn.states.DiscreteStates

Returns a batch of all terminating states for environments with enumerable states. The batch_shape should be (n_terminating_states,). This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)

Return type

gfn.states.DiscreteStates

property true_dist_pmf: torch.Tensor

Returns a one-dimensional tensor representing the true distribution.

Return type

torch.Tensor

build_grid()

Utility function that builds the complete grid, i.e. a batch containing every state of the environment.

Return type

gfn.states.DiscreteStates
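Building the complete grid amounts to enumerating every cell of {0, …, height - 1}^ndim. A plain-Python stand-in, assuming integer coordinates and lexicographic order, is shown below; this build_grid is a hypothetical helper returning tuples rather than the DiscreteStates batch that the real method returns.

```python
import itertools
from typing import List, Tuple


def build_grid(ndim: int, height: int) -> List[Tuple[int, ...]]:
    """Hypothetical stand-in for build_grid: every cell of {0, ..., height-1}^ndim."""
    return list(itertools.product(range(height), repeat=ndim))


grid = build_grid(ndim=2, height=4)
# 4**2 = 16 cells, from (0, 0) to (3, 3) in lexicographic order.
```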

get_states_indices(states)
Parameters

states (gfn.states.DiscreteStates) –

Return type

torchtyping.TensorType[batch_shape, torch.long]

get_terminating_states_indices(states)
Parameters

states (gfn.states.DiscreteStates) –

Return type

torchtyping.TensorType[batch_shape, torch.long]
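Neither index method is described on this page. One natural encoding, used here purely as an assumption, reads the coordinates of a cell as the digits of a base-height number, which yields the documented property get_states_indices(all_states) == arange(n_states) when cells are enumerated lexicographically. hypergrid_state_index is a hypothetical helper.

```python
import itertools
from typing import Sequence


def hypergrid_state_index(state: Sequence[int], height: int) -> int:
    """Hypothetical encoding: read the coordinates as digits of a base-`height` number."""
    index = 0
    for coord in state:
        index = height * index + coord  # first coordinate is the most significant digit
    return index


ndim, height = 3, 4
grid = list(itertools.product(range(height), repeat=ndim))
# Indices run from 0 to n_states - 1 in enumeration order.
assert [hypergrid_state_index(s, height) for s in grid] == list(range(height ** ndim))
```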

log_reward(final_states)

Returns the logarithm of the reward of each final state. Either this method or reward needs to be implemented.

Parameters

final_states (gfn.states.DiscreteStates) –

Return type

torchtyping.TensorType[batch_shape, torch.float]

make_States_class()

Creates a States class for this environment.

Return type

type[gfn.states.DiscreteStates]

maskless_backward_step(states, actions)

Function that takes a batch of states and actions and returns a batch of previous states. Does not need to check whether the actions are valid or the states are sink states.

Parameters
  • states –

  • actions –

Return type

torchtyping.TensorType[batch_shape, state_shape, torch.float]

maskless_step(states, actions)

Function that takes a batch of states and actions and returns a batch of next states. Does not need to check whether the actions are valid or the states are sink states.

Parameters
  • states –

  • actions –

Return type

torchtyping.TensorType[batch_shape, state_shape, torch.float]
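Assuming, for illustration, that a non-exit action d moves a state one cell along dimension d (forward: +1, backward: -1), the two maskless steps reduce to the sketch below. The function names are hypothetical, and bounds are deliberately not checked because validity is enforced by the masks, not by these functions.

```python
from typing import List


def hypergrid_step(state: List[int], action: int) -> List[int]:
    """Hypothetical forward step: action d adds 1 to coordinate d (no validity check)."""
    nxt = list(state)
    nxt[action] += 1
    return nxt


def hypergrid_backward_step(state: List[int], action: int) -> List[int]:
    """Hypothetical backward step: action d subtracts 1 from coordinate d."""
    prev = list(state)
    prev[action] -= 1
    return prev


s = hypergrid_step(hypergrid_step([0, 0], 0), 1)  # -> [1, 1]
back = hypergrid_backward_step(s, 1)               # -> [1, 0]
```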

true_reward(final_states)

In the normal setting, the reward is:

R(s) = R_0 + 0.5 \prod_{d=1}^{D} \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^{D} \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)

where D is ndim, H is height, and s^d is the d-th coordinate of s.

Parameters

final_states (gfn.states.DiscreteStates) –

Return type

torchtyping.TensorType[batch_shape, torch.float]
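The reward formula above can be evaluated cell by cell in plain Python. The coefficients 0.5 and 2 are taken literally from the formula (they coincide with the defaults of R1 and R2). The normalisation step additionally assumes that every grid cell is a terminating state, so that log Z and the true PMF can be obtained by summing over the whole grid; hypergrid_true_reward is a hypothetical helper.

```python
import itertools
import math
from typing import Sequence


def hypergrid_true_reward(state: Sequence[int], height: int, R0: float = 0.1) -> float:
    """Evaluate the documented reward R(s) for one grid cell s."""
    # |s^d / (H - 1) - 0.5| for every coordinate d.
    ax = [abs(sd / (height - 1) - 0.5) for sd in state]
    outer = all(0.25 < a <= 0.5 for a in ax)  # first indicator product
    ring = all(0.3 < a < 0.4 for a in ax)     # second indicator product
    return R0 + 0.5 * outer + 2.0 * ring


ndim, height = 2, 8
cells = list(itertools.product(range(height), repeat=ndim))
rewards = [hypergrid_true_reward(s, height) for s in cells]
z = sum(rewards)
log_partition = math.log(z)       # log Z, assuming every cell is terminating
pmf = [r / z for r in rewards]    # normalised target distribution
```

With height = 8, the outer band contains coordinates {0, 1, 6, 7} and the inner band {1, 6}, so the 64 cells contribute 64 · 0.1 + 16 · 0.5 + 4 · 2 = 22.4 to Z.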