gfn.gym.hypergrid
Copied and adapted from https://github.com/Tikquuss/GflowNets_Tutorial
Module Contents
Classes
HyperGrid: Base class for discrete environments, where actions are represented by a number in {0, …, n_actions - 1}, the last one being the exit action.
- class gfn.gym.hypergrid.HyperGrid(ndim=2, height=4, R0=0.1, R1=0.5, R2=2.0, reward_cos=False, device_str='cpu', preprocessor_name='KHot')
Bases: gfn.env.DiscreteEnv
Base class for discrete environments, where actions are represented by a number in {0, …, n_actions - 1}, the last one being the exit action. DiscreteEnv allows specifying the validity of forward and backward actions via mask tensors that are attached directly to States objects.
- Parameters
ndim (int) –
height (int) –
R0 (float) –
R1 (float) –
R2 (float) –
reward_cos (bool) –
device_str (Literal[cpu, cuda]) –
preprocessor_name (Literal[KHot, OneHot, Identity, Enum]) –
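To illustrate the action masks mentioned in the class description, here is a minimal plain-Python sketch of what forward and backward validity masks could look like on a hypergrid. It assumes that action d increments coordinate d, that the last action is the exit action and is always valid, and that states are plain tuples rather than States objects. The helper names `forward_mask` and `backward_mask` are hypothetical and are not part of the library.

```python
def forward_mask(state, height):
    """Return one boolean per action: ndim increments followed by the exit action."""
    # Incrementing dimension d is valid only if the coordinate stays inside the grid.
    mask = [coord < height - 1 for coord in state]
    # Assumption: the exit action (last index) is always allowed.
    mask.append(True)
    return mask


def backward_mask(state):
    """Return one boolean per dimension: a coordinate can be decremented if positive."""
    return [coord > 0 for coord in state]


print(forward_mask((0, 3), height=4))  # [True, False, True]
print(backward_mask((0, 3)))           # [False, True]
```

In the library itself the masks are tensors attached to the States objects; the lists above only mirror the logic for a single state.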
- property all_states: gfn.states.DiscreteStates
Returns a batch of all states. The batch_shape should be (n_states,). This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states)
- Return type
gfn.states.DiscreteStates
- property log_partition: float
Returns the logarithm of the partition function.
- Return type
float
- property n_states: int
- Return type
int
- property n_terminating_states: int
- Return type
int
- property terminating_states: gfn.states.DiscreteStates
Returns a batch of all terminating states for environments with enumerable states. The batch_shape should be (n_terminating_states,). This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)
- Return type
gfn.states.DiscreteStates
- property true_dist_pmf: torch.Tensor
Returns a one-dimensional tensor representing the true distribution.
- Return type
torch.Tensor
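As a rough illustration of how log_partition and true_dist_pmf relate, the sketch below uses the standard GFlowNet definitions: the partition function Z is the sum of rewards over all terminating states, and the true distribution assigns each state the probability R(s) / Z. It is written in plain Python and is not the library's implementation. The function `log_partition_and_pmf` and the `toy_reward` used in the example are hypothetical.

```python
import itertools
import math


def log_partition_and_pmf(reward_fn, ndim, height):
    """Enumerate every grid state, then normalize the rewards into a distribution."""
    # Assumption: every grid state is a terminating state.
    states = list(itertools.product(range(height), repeat=ndim))
    rewards = [reward_fn(s) for s in states]
    partition = sum(rewards)
    pmf = [r / partition for r in rewards]
    return math.log(partition), pmf


def toy_reward(state):
    """Hypothetical reward: 0.6 on the four corners of a 4x4 grid, 0.1 elsewhere."""
    return 0.6 if all(coord in (0, 3) for coord in state) else 0.1


log_z, pmf = log_partition_and_pmf(toy_reward, ndim=2, height=4)
print(len(pmf), round(sum(pmf), 6))
```

Exhaustive enumeration like this is only practical for small grids, since the number of states grows as height ** ndim.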
- build_grid()
Utility function to build the complete grid.
- Return type
- get_states_indices(states)
- Parameters
states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[batch_shape, torch.long]
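One indexing scheme that satisfies the identity stated for all_states (the indices of all states equal 0, …, n_states - 1) is to read the coordinates as digits of a base-height number. The sketch below assumes that scheme and a row-major enumeration of the grid. The actual ordering used by the library may differ, and `state_index` and `enumerate_states` are hypothetical names.

```python
import itertools


def state_index(state, height):
    """Interpret the coordinates as base-`height` digits, most significant first."""
    index = 0
    for coord in state:
        index = index * height + coord
    return index


def enumerate_states(ndim, height):
    """List all grid states in row-major order (last coordinate varies fastest)."""
    return list(itertools.product(range(height), repeat=ndim))


states = enumerate_states(ndim=2, height=4)
indices = [state_index(s, height=4) for s in states]
print(indices == list(range(4 ** 2)))  # True
```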
- get_terminating_states_indices(states)
- Parameters
states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[batch_shape, torch.long]
- log_reward(final_states)
Either this or reward needs to be implemented.
- Parameters
final_states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[batch_shape, torch.float]
- make_States_class()
Creates a States class for this environment.
- Return type
- maskless_backward_step(states, actions)
Function that takes a batch of states and actions and returns a batch of previous states. Does not need to check whether the actions are valid or the states are sink states.
- Parameters
states (gfn.states.DiscreteStates) –
actions (gfn.actions.Actions) –
- Return type
torchtyping.TensorType[batch_shape, state_shape, torch.float]
- maskless_step(states, actions)
Function that takes a batch of states and actions and returns a batch of next states. Does not need to check whether the actions are valid or the states are sink states.
- Parameters
states (gfn.states.DiscreteStates) –
actions (gfn.actions.Actions) –
- Return type
torchtyping.TensorType[batch_shape, state_shape, torch.float]
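The two maskless transitions can be pictured as follows for a single state. This is a plain-Python sketch under the assumption that action d adds 1 to coordinate d in the forward direction and subtracts 1 in the backward direction. It works on tuples instead of batched tensors, and it leaves validity checks and the exit action to the caller, as the descriptions above indicate.

```python
def maskless_step(state, action):
    """Return the next state: coordinate `action` is incremented by one."""
    next_state = list(state)
    next_state[action] += 1
    return tuple(next_state)


def maskless_backward_step(state, action):
    """Return the previous state: coordinate `action` is decremented by one."""
    prev_state = list(state)
    prev_state[action] -= 1
    return tuple(prev_state)


s = (0, 1)
s_next = maskless_step(s, action=0)
print(s_next)                                    # (1, 1)
print(maskless_backward_step(s_next, action=0))  # (0, 1)
```

Under this assumption the backward step undoes the forward step for the same action, which is what makes the two functions a consistent pair.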
- true_reward(final_states)
In the normal setting, the reward is:
R(s) = R_0 + 0.5 \prod_{d=1}^D \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^D \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)
- Parameters
final_states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[batch_shape, torch.float]
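The reward formula for the normal setting can be sketched for a single state in plain Python. This version assumes that the coefficients 0.5 and 2 in the formula correspond to the constructor defaults R1 and R2, that H is the height parameter, and that a state is a tuple of integer coordinates. It does not cover the reward_cos variant and is not the library's batched tensor implementation.

```python
def true_reward(state, height, R0=0.1, R1=0.5, R2=2.0):
    """Evaluate R(s) = R0 + R1 * [outer band] + R2 * [inner band] for one state."""
    # Distance of each normalized coordinate from the grid centre.
    ax = [abs(coord / (height - 1) - 0.5) for coord in state]
    # Each product of indicators is 1 only if every dimension falls in the band.
    outer = all(0.25 < a <= 0.5 for a in ax)
    inner = all(0.3 < a < 0.4 for a in ax)
    return R0 + R1 * outer + R2 * inner


# With height=8, coordinates 1 and 6 fall in both bands, 0 and 7 only in the outer one.
for state in [(3, 3), (0, 0), (1, 6)]:
    print(state, true_reward(state, height=8))
```

With the default height of 4, no coordinate lands in the (0.3, 0.4) band, so only the R0 and R1 terms contribute; larger grids are needed to see the high-reward modes.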