gfn.gym.box
Module Contents
Classes
Box: Box environment, corresponding to the one in Section 4.1 of https://arxiv.org/abs/2301.12594
- class gfn.gym.box.Box(delta=0.1, R0=0.1, R1=0.5, R2=2.0, epsilon=0.0001, device_str='cpu')
Bases: gfn.env.Env
Box environment, corresponding to the one in Section 4.1 of https://arxiv.org/abs/2301.12594
- Parameters
delta (float) –
R0 (float) –
R1 (float) –
R2 (float) –
epsilon (float) –
device_str (Literal['cpu', 'cuda']) –
- property log_partition: float
Returns the logarithm of the partition function.
- Return type
float
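The partition function can be worked out by hand if one assumes the reward follows the definition in the cited paper: R(x) = R0 + R1·1[|x_d - 0.5| > 0.25 for every d] + R2·1[0.3 < |x_d - 0.5| < 0.4 for every d] on the unit square. This formula is an assumption; this page does not state it. Under it, the R1 region has area 0.5² = 0.25 and the R2 region has area 0.2² = 0.04, so Z = R0 + 0.25·R1 + 0.04·R2. The helper below, box_log_partition, is hypothetical. It is a plain-Python sketch of that arithmetic, not the library's implementation.

```python
import math


def box_log_partition(R0: float = 0.1, R1: float = 0.5, R2: float = 2.0) -> float:
    """Hypothetical helper: log Z, assuming the paper's Box reward."""
    # R1 bonus region: |x_d - 0.5| > 0.25 in both coordinates.
    # Per coordinate that is [0, 0.25) U (0.75, 1], total length 0.5.
    area_r1 = (2 * 0.25) ** 2
    # R2 bonus region: 0.3 < |x_d - 0.5| < 0.4 in both coordinates.
    # Per coordinate that is (0.1, 0.2) U (0.8, 0.9), total length 0.2.
    area_r2 = (2 * 0.1) ** 2
    # R0 is earned everywhere on the unit square, which has area 1.
    return math.log(R0 + area_r1 * R1 + area_r2 * R2)
```

With the default parameters this gives log(0.1 + 0.125 + 0.08) = log(0.305), roughly -1.19.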
- is_action_valid(states, actions, backward=False)
Returns True if every action in the batch is valid in its corresponding state. If backward is True, the actions are checked as backward actions.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
backward (bool) –
- Return type
bool
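The following is a simplified sketch of the forward checks only. It assumes the action rules that the cited paper gives for this environment. The agent starts at s0 = (0, 0), and its first step may land anywhere within distance delta of s0. Every later step has length exactly delta. All steps move up and to the right without leaving the unit square. The function name forward_actions_valid and the tolerance tol are hypothetical. Exit actions and the backward direction are not handled.

```python
import math
from typing import Sequence, Tuple

Vec = Tuple[float, float]


def forward_actions_valid(states: Sequence[Vec], actions: Sequence[Vec],
                          delta: float = 0.1, tol: float = 1e-5) -> bool:
    """Hypothetical sketch: True only if every (state, action) pair is valid."""
    for (x, y), (dx, dy) in zip(states, actions):
        # Assumed rule: actions only move up and/or to the right.
        if dx < 0 or dy < 0:
            return False
        step_len = math.hypot(dx, dy)
        if (x, y) == (0.0, 0.0):
            # From the source state, the first step may have any length up to delta.
            if step_len > delta:
                return False
        elif abs(step_len - delta) > tol:
            # Elsewhere, every step must have length delta (up to tol).
            return False
        # The next state must remain inside the unit square.
        if x + dx > 1 or y + dy > 1:
            return False
    return True
```

Returning a single bool for the whole batch matches the documented return type. One invalid pair makes the entire batch invalid.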
- log_reward(final_states)
Returns the log-reward of each final state in the batch. (The base environment requires either this method or reward to be implemented.)
- Parameters
final_states (gfn.states.States) –
- Return type
torchtyping.TensorType[batch_shape, torch.float]
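The sketch below computes the log-reward, assuming the reward from Section 4.1 of the cited paper. That reward is a base level R0 everywhere, plus R1 when every coordinate is more than 0.25 from the center, plus R2 when every coordinate is between 0.3 and 0.4 from the center. This formula is an assumption; the page above does not state it. The function box_log_reward is hypothetical, and plain Python lists stand in for the batched tensor.

```python
import math
from typing import List, Sequence, Tuple


def box_log_reward(final_states: Sequence[Tuple[float, float]],
                   R0: float = 0.1, R1: float = 0.5, R2: float = 2.0) -> List[float]:
    """Hypothetical sketch: log R(x) for each final state x in the unit square."""
    out = []
    for x in final_states:
        # Distance of each coordinate from the center of the box.
        dist = [abs(xd - 0.5) for xd in x]
        in_r1 = all(d > 0.25 for d in dist)       # outer band in every coordinate
        in_r2 = all(0.3 < d < 0.4 for d in dist)  # narrower band inside the outer one
        out.append(math.log(R0 + R1 * float(in_r1) + R2 * float(in_r2)))
    return out
```

Because the R2 band lies inside the R1 band, points in it earn R0 + R1 + R2 (2.6 with the defaults).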
- make_Actions_class()
Returns a class that inherits from Actions and implements the environment-specific methods.
- Return type
type[gfn.actions.Actions]
- make_States_class()
Returns a class that inherits from States and implements the environment-specific methods.
- Return type
type[gfn.states.States]
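make_Actions_class and make_States_class follow a class-factory pattern. Each returns a new subclass whose class-level attributes are fixed by the environment. The sketch below illustrates the pattern with a hypothetical stand-in base class. BaseStates is not the library's gfn.states.States, and make_box_states_class is not a library function. The sketch also assumes a two-dimensional state whose source state s0 is the origin.

```python
from typing import ClassVar, Sequence, Tuple


class BaseStates:
    """Hypothetical stand-in for a generic States base class."""

    state_shape: ClassVar[Tuple[int, ...]] = ()
    s0: ClassVar[Tuple[float, ...]] = ()

    def __init__(self, values: Sequence[float]) -> None:
        # Uses the class attribute that the factory fills in.
        if len(values) != self.state_shape[0]:
            raise ValueError("wrong state dimension")
        self.values = tuple(values)


def make_box_states_class(dim: int = 2) -> type:
    """Hypothetical factory: a BaseStates subclass specialised to a dim-D box."""

    class BoxStates(BaseStates):
        # Environment-specific attributes are fixed when the class is built.
        state_shape = (dim,)
        s0 = (0.0,) * dim

    return BoxStates
```

Returning a class rather than an instance lets generic sampling code create many state batches that all share the environment's shape and source state.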
- maskless_backward_step(states, actions)
Takes a batch of states and actions and returns the batch of previous states. It does not need to check whether the actions are valid or whether the states are sink states.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
- Return type
torchtyping.TensorType[batch_shape, 2, torch.float]
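Assume a Box action is stored as a 2-D displacement vector. The backward step then reduces to subtracting the action from the state. The plain-Python sketch below (the name box_backward_step is hypothetical) does exactly that. Like the method it illustrates, it performs no validity check, so an invalid action can produce a point outside the box.

```python
from typing import List, Sequence, Tuple

Vec = Tuple[float, float]


def box_backward_step(states: Sequence[Vec], actions: Sequence[Vec]) -> List[Vec]:
    """Hypothetical sketch: previous state = state - action, with no checks."""
    return [(x - dx, y - dy) for (x, y), (dx, dy) in zip(states, actions)]
```

For example, box_backward_step([(0.05, 0.0)], [(0.1, 0.0)]) returns a point with a negative x coordinate. Such actions have to be rejected elsewhere, for example by is_action_valid with backward=True.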
- maskless_step(states, actions)
Takes a batch of states and actions and returns the batch of next states. It does not need to check whether the actions are valid or whether the states are sink states.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
- Return type
torchtyping.TensorType[batch_shape, 2, torch.float]
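Under the same assumption, the forward step adds each 2-D displacement to its state. The helper box_step below is hypothetical and uses plain Python tuples in place of a [batch_shape, 2] tensor.

```python
from typing import List, Sequence, Tuple

Vec = Tuple[float, float]


def box_step(states: Sequence[Vec], actions: Sequence[Vec]) -> List[Vec]:
    """Hypothetical sketch: next state = state + action, with no validity checks."""
    return [(x + dx, y + dy) for (x, y), (dx, dy) in zip(states, actions)]
```

Keeping the update free of masks and checks keeps it cheap. The validity logic lives in one place (is_action_valid) and is not repeated in every step.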
- static norm(x)
- Parameters
x (torchtyping.TensorType[batch_shape, 2, torch.float]) –
- Return type
torch.Tensor
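norm has no description on this page. Its signature maps a [batch_shape, 2] tensor to a tensor. A natural reading, which is an assumption here, is that it returns the Euclidean length of each 2-D vector, as the step-length rules of this environment require. In torch that could be written as torch.linalg.vector_norm(x, dim=-1). The sketch below is a plain-Python equivalent with the hypothetical name box_norm.

```python
import math
from typing import List, Sequence, Tuple


def box_norm(x: Sequence[Tuple[float, float]]) -> List[float]:
    """Hypothetical sketch: Euclidean length of each 2-D row of the batch."""
    return [math.hypot(a, b) for a, b in x]
```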