gfn.gym.discrete_ebm
Module Contents
Classes
- DiscreteEBM: Environment for discrete energy-based models, based on https://arxiv.org/pdf/2202.01361.pdf
- EnergyFunction: Base class for energy functions
- IsingModel: Ising model energy function
- class gfn.gym.discrete_ebm.DiscreteEBM(ndim, energy=None, alpha=1.0, device_str='cpu', preprocessor_name='Identity')
Bases: gfn.env.DiscreteEnv
Environment for discrete energy-based models, based on https://arxiv.org/pdf/2202.01361.pdf
- Parameters
ndim (int) –
energy (EnergyFunction | None) –
alpha (float) –
device_str (Literal['cpu', 'cuda']) –
preprocessor_name (Literal['Identity', 'Enum']) –
- property all_states: gfn.states.DiscreteStates
Returns a batch of all states. The batch_shape should be (n_states,). This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states)
- Return type
gfn.states.DiscreteStates
- property log_partition: float
Returns the logarithm of the partition function, i.e. the log of the sum of the rewards of all terminating states.
- Return type
float
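The quantity behind log_partition can be sketched in plain Python. This is a minimal illustration, not the library's code: it assumes the log rewards of all terminating states are available as an explicit list and computes log Z = log sum_x exp(log R(x)) with the max-shifted log-sum-exp trick. The helper name log_partition_from_log_rewards is hypothetical.

```python
import math
from typing import Sequence


def log_partition_from_log_rewards(log_rewards: Sequence[float]) -> float:
    """Hypothetical helper: stable log(sum(exp(r))) over all terminating states."""
    if not log_rewards:
        raise ValueError("need at least one terminating state")
    # Subtract the maximum before exponentiating so that exp() cannot overflow.
    m = max(log_rewards)
    return m + math.log(sum(math.exp(r - m) for r in log_rewards))


# Four terminating states, each with reward 1.0 (log reward 0.0), so Z = 4.
print(log_partition_from_log_rewards([0.0, 0.0, 0.0, 0.0]))  # log(4), about 1.386
```

Shifting by the maximum matters for large log rewards: math.exp(1000.0) alone raises OverflowError, while the shifted form returns 1000 + log 2 for [1000.0, 1000.0].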
- property n_states: int
- Return type
int
- property n_terminating_states: int
- Return type
int
- property terminating_states: gfn.states.DiscreteStates
Returns a batch of all terminating states for environments with enumerable states. The batch_shape should be (n_terminating_states,). This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)
- Return type
gfn.states.DiscreteStates
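The index invariant stated for terminating_states can be illustrated with a small pure-Python sketch. It assumes, since this page does not say so, that terminating states are the fully assigned vectors in {0, 1}^ndim and that a state's index is the binary number it spells, with the first coordinate as the most significant bit. Both function names are hypothetical.

```python
from itertools import product
from typing import List, Tuple


def terminating_states(ndim: int) -> List[Tuple[int, ...]]:
    """All fully assigned states, assumed to be {0, 1}^ndim, in lexicographic order."""
    return list(product((0, 1), repeat=ndim))


def terminating_state_index(state: Tuple[int, ...]) -> int:
    """Read a 0/1 vector as a binary number, first coordinate most significant (assumed)."""
    index = 0
    for bit in state:
        index = 2 * index + bit
    return index


# Analogue of the documented invariant: the enumeration maps onto 0..n_terminating_states-1.
ndim = 3
states = terminating_states(ndim)
assert [terminating_state_index(s) for s in states] == list(range(2 ** ndim))
```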
- property true_dist_pmf: torch.Tensor
Returns a one-dimensional tensor representing the true distribution over terminating states, i.e. each terminating state's reward divided by the partition function.
- Return type
torch.Tensor
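Assuming the true distribution is the reward-proportional one, p(x) = R(x) / Z, it can be computed in log space as exp(log R(x) - log Z). The sketch below does this for an explicit list of log rewards; true_pmf is a hypothetical helper, and it recomputes log Z inline so that the block stands on its own.

```python
import math
from typing import List, Sequence


def true_pmf(log_rewards: Sequence[float]) -> List[float]:
    """p(x) = R(x) / Z, computed in log space for numerical stability."""
    m = max(log_rewards)
    # log Z via the max-shifted log-sum-exp trick.
    log_z = m + math.log(sum(math.exp(r - m) for r in log_rewards))
    return [math.exp(r - log_z) for r in log_rewards]


pmf = true_pmf([0.0, math.log(3.0)])  # rewards 1 and 3
print(pmf)  # approximately [0.25, 0.75]
```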
- get_states_indices(states)
Each coordinate is encoded as -1 -> 0, 0 -> 1, 1 -> 2, and the resulting digit vector is then read as a base-3 number.
- Parameters
states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[batch_shape]
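The base-3 encoding described above can be sketched in plain Python. This is an illustration, not the library's implementation: state_index is a hypothetical name, states are plain tuples rather than DiscreteStates, and treating the first coordinate as the most significant digit is an assumption, because the docstring does not fix the digit order. Under that assumption, enumerating {-1, 0, 1}^ndim lexicographically with itertools.product reproduces the all_states invariant.

```python
from itertools import product
from typing import Sequence


def state_index(state: Sequence[int]) -> int:
    """Map a state with entries in {-1, 0, 1} to an integer in [0, 3**ndim)."""
    index = 0
    for value in state:
        if value not in (-1, 0, 1):
            raise ValueError(f"invalid coordinate value: {value}")
        # Shift the digit: -1 -> 0, 0 -> 1, 1 -> 2.
        digit = value + 1
        # Horner's rule; the first coordinate becomes the most significant digit (assumed).
        index = 3 * index + digit
    return index


# Lexicographic enumeration of all states matches indices 0..3**ndim - 1.
ndim = 2
all_states = list(product((-1, 0, 1), repeat=ndim))
assert [state_index(s) for s in all_states] == list(range(3 ** ndim))
print(state_index((1, -1)))  # 2 * 3 + 0 = 6
```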
- get_terminating_states_indices(states)
- Parameters
states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[batch_shape]
- is_exit_actions(actions)
- Parameters
actions (torchtyping.TensorType[batch_shape]) –
- Return type
torchtyping.TensorType[batch_shape]
- log_reward(final_states)
Either this method or reward must be implemented.
- Parameters
final_states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[batch_shape]
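The note that either log_reward or reward must be implemented suggests a pair of mutually defined defaults. The sketch below shows that pattern in plain Python. RewardMixin and ToyEBM are hypothetical, and the convention log R(x) = -alpha * E(x) is an assumption about how alpha enters, not something this page states.

```python
import math


class RewardMixin:
    """Hypothetical base: subclasses override reward() or log_reward(), not both."""

    def reward(self, x):
        # Default: derive the reward from the log reward.
        return math.exp(self.log_reward(x))

    def log_reward(self, x):
        # Default: derive the log reward from the reward.
        return math.log(self.reward(x))


class ToyEBM(RewardMixin):
    """Hypothetical environment with log R(x) = -alpha * E(x) (assumed convention)."""

    def __init__(self, alpha: float = 1.0):
        self.alpha = alpha

    def energy(self, x):
        # Toy energy: the number of ones in the configuration.
        return float(sum(x))

    def log_reward(self, x):
        return -self.alpha * self.energy(x)


env = ToyEBM(alpha=0.5)
print(env.reward((1, 1)))  # exp(-1.0), about 0.3679
```

If a subclass overrides neither method, the two defaults call each other until Python raises RecursionError, which is why one of them has to be implemented.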
- make_States_class()
Returns a class that inherits from States and implements the environment-specific methods.
- Return type
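make_States_class is a class factory: it returns a States subclass specialized to the environment. The sketch below shows the pattern with hypothetical stand-ins (States, make_states_class, EBMStates, initial); the attributes and the all-(-1) initial state are assumptions for illustration, not the library's actual States interface.

```python
from typing import ClassVar, Tuple, Type


class States:
    """Hypothetical stand-in for a generic States base class."""

    state_shape: ClassVar[Tuple[int, ...]] = ()
    s0: ClassVar[Tuple[int, ...]] = ()

    def __init__(self, tensor):
        self.tensor = tensor


def make_states_class(ndim: int) -> Type[States]:
    """Build a States subclass whose class attributes are bound to one environment."""

    class EBMStates(States):
        state_shape = (ndim,)
        # Assumed initial state: every coordinate unassigned (-1).
        s0 = tuple([-1] * ndim)

        @classmethod
        def initial(cls) -> "EBMStates":
            return cls(list(cls.s0))

    return EBMStates


EBMStates3 = make_states_class(ndim=3)
print(EBMStates3.initial().tensor)  # [-1, -1, -1]
```

Defining the subclass inside a function lets each environment instance produce a class whose shape and initial state match its own ndim.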
- maskless_backward_step(states, actions)
Takes a batch of states and actions and returns the corresponding batch of previous states. It does not need to check whether the actions are valid or whether the states are sink states.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
- Return type
torchtyping.TensorType[batch_shape, state_shape, torch.float]
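A backward step on a single state can be sketched as undoing one assignment. This page does not describe the action encoding, so the sketch assumes one: actions 0..ndim-1 set a coordinate to 0, actions ndim..2*ndim-1 set a coordinate to 1, and undoing either resets that coordinate to -1. backward_step is a hypothetical helper that works on plain lists rather than States and Actions.

```python
from typing import List


def backward_step(state: List[int], action: int, ndim: int) -> List[int]:
    """Undo an assignment: reset coordinate (action mod ndim) to -1 (assumed encoding)."""
    # Assumed encoding: actions 0..ndim-1 set a coordinate to 0,
    # actions ndim..2*ndim-1 set a coordinate to 1.
    if not 0 <= action < 2 * ndim:
        raise ValueError("action has no inverse under the assumed encoding")
    previous = list(state)  # copy so the input state is left untouched
    previous[action % ndim] = -1
    # Like maskless_backward_step, no check that the coordinate was actually assigned.
    return previous


print(backward_step([0, 1, -1], action=4, ndim=3))  # [0, -1, -1]
```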
- maskless_step(states, actions)
Takes a batch of states and actions and returns the corresponding batch of next states. It does not need to check whether the actions are valid or whether the states are sink states.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
- Return type
torchtyping.TensorType[batch_shape, state_shape, torch.float]
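A single forward step can be sketched in plain Python. This page does not describe the action encoding, so the sketch assumes one: -1 marks an unassigned coordinate, actions 0..ndim-1 set coordinate `action` to 0, and actions ndim..2*ndim-1 set coordinate `action - ndim` to 1. forward_step is hypothetical, works on plain lists, and leaves exit actions out.

```python
from typing import List


def forward_step(state: List[int], action: int, ndim: int) -> List[int]:
    """Assign one coordinate under a hypothetical action encoding."""
    nxt = list(state)  # copy so the input state is left untouched
    if 0 <= action < ndim:
        nxt[action] = 0  # actions 0..ndim-1: set coordinate `action` to 0
    elif ndim <= action < 2 * ndim:
        nxt[action - ndim] = 1  # actions ndim..2*ndim-1: set coordinate to 1
    else:
        raise ValueError("exit or out-of-range action; not handled by this sketch")
    # Like maskless_step, no check that the coordinate was previously unassigned.
    return nxt


print(forward_step([-1, -1, -1], action=4, ndim=3))  # [-1, 1, -1]
```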
- class gfn.gym.discrete_ebm.EnergyFunction
Bases: torch.nn.Module, abc.ABC
Base class for energy functions
- abstract forward(states)
- Parameters
states (torchtyping.TensorType[batch_shape, state_shape, torch.float]) –
- Return type
torchtyping.TensorType[batch_shape]
- class gfn.gym.discrete_ebm.IsingModel(J)
Bases: EnergyFunction
Ising model energy function
- Parameters
J (torchtyping.TensorType[state_shape, state_shape, torch.float]) –
- forward(states)
- Parameters
states (torchtyping.TensorType[batch_shape, state_shape, torch.float]) –
- Return type
torchtyping.TensorType[batch_shape]
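A minimal pure-Python sketch of an Ising energy with coupling matrix J follows. This page only says that IsingModel takes J of shape (state_shape, state_shape); the convention E(s) = -s^T J s for spins s in {-1, +1}^n, with no 1/2 factor and no external field, is an assumption, as is mapping 0/1 states to spins with s = 2x - 1. ising_energy is a hypothetical helper.

```python
from typing import Sequence


def ising_energy(J: Sequence[Sequence[float]], spins: Sequence[float]) -> float:
    """E(s) = -sum_{i,j} s_i J_ij s_j (assumed convention, no external field)."""
    n = len(spins)
    if len(J) != n or any(len(row) != n for row in J):
        raise ValueError("J must be an n x n matrix matching the number of spins")
    return -sum(spins[i] * J[i][j] * spins[j] for i in range(n) for j in range(n))


J = [[1.0, 1.0], [1.0, 1.0]]  # uniform ferromagnetic coupling
print(ising_energy(J, [1, 1]))    # -4.0: aligned spins have the lowest energy
print(ising_energy(J, [-1, -1]))  # -4.0: the energy is invariant under a global flip
```

Because the energy is quadratic in s, flipping every spin leaves it unchanged, so with this J both fully aligned configurations share the minimum.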