gfn.env
Module Contents
Classes
DiscreteEnv: Base class for discrete environments, where actions are represented by a number in {0, …, n_actions - 1}, the last one being the exit action.
Env: Base class for all environments. Environments require that individual states be represented as a unique tensor of arbitrary shape.
Attributes
- class gfn.env.DiscreteEnv(n_actions, s0, sf=None, device_str=None, preprocessor=None)
Bases: Env, abc.ABC
Base class for discrete environments, where actions are represented by a number in {0, …, n_actions - 1}, the last one being the exit action. DiscreteEnv allows specifying the validity of actions (forward and backward) via mask tensors that are directly attached to States objects.
- Parameters
n_actions (int) –
s0 (torchtyping.TensorType[state_shape, torch.float]) –
sf (Optional[torchtyping.TensorType[state_shape, torch.float]]) –
device_str (Optional[str]) –
preprocessor (Optional[gfn.preprocessors.Preprocessor]) –
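To make the action convention concrete, here is a minimal sketch in plain Python that deliberately does not use the gfn API. It shows forward masks for a hypothetical one-dimensional chain environment. The names `N_ACTIONS`, `EXIT_ACTION`, `MAX_POS`, and `forward_mask` are illustrative assumptions, not library identifiers.

```python
# Toy illustration only: a chain of positions 0..MAX_POS.
# Action 0 moves right by one; action 1 (the last index) is the exit action.
N_ACTIONS = 2
EXIT_ACTION = N_ACTIONS - 1  # the last action index is the exit action
MAX_POS = 3


def forward_mask(position):
    """Return one boolean per action: True where the action is allowed."""
    can_move_right = position < MAX_POS
    can_exit = True  # assumption: exiting is always allowed in this toy chain
    return [can_move_right, can_exit]


# One mask per state, analogous to masks attached to a batch of states.
masks = [forward_mask(p) for p in range(MAX_POS + 1)]
```

At the last position, only the exit action remains valid, which is the kind of constraint a mask tensor encodes for a whole batch at once.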
- property all_states: gfn.states.DiscreteStates
Returns a batch of all states. The batch_shape should be (n_states,). This should satisfy: self.get_states_indices(self.all_states) == torch.arange(self.n_states)
- Return type
gfn.states.DiscreteStates
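The consistency condition between all_states and get_states_indices can be illustrated with a small sketch in plain Python. It assumes a hypothetical square grid of side `H` whose states are `(x, y)` pairs. The function below only mirrors the name of the documented method and is not the library implementation.

```python
H = 3  # hypothetical grid side length


def get_states_indices(states):
    """Map each (x, y) state to a unique integer in 0..H*H-1 (row-major order)."""
    return [x * H + y for (x, y) in states]


# Enumerate states in the same row-major order used by the index function.
all_states = [(x, y) for x in range(H) for y in range(H)]
n_states = len(all_states)
indices = get_states_indices(all_states)
```

Because the enumeration order matches the indexing order, `indices` equals `list(range(n_states))`, which is the plain-Python analogue of the `torch.arange(self.n_states)` condition.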
- property n_states: int
- Return type
int
- property n_terminating_states: int
- Return type
int
- property terminating_states: gfn.states.DiscreteStates
Returns a batch of all terminating states for environments with enumerable states. The batch_shape should be (n_terminating_states,). This should satisfy: self.get_terminating_states_indices(self.terminating_states) == torch.arange(self.n_terminating_states)
- Return type
gfn.states.DiscreteStates
- property true_dist_pmf: torchtyping.TensorType[n_states, torch.float]
Returns a one-dimensional tensor representing the true distribution.
- Return type
torchtyping.TensorType[n_states, torch.float]
- get_states_indices(states)
- Parameters
states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[batch_shape, torch.long]
- get_terminating_states_indices(states)
- Parameters
states (gfn.states.DiscreteStates) –
- Return type
torchtyping.TensorType[batch_shape, torch.long]
- is_action_valid(states, actions, backward=False)
Returns True if the actions are valid in the given states.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
backward (bool) –
- Return type
bool
- make_Actions_class()
Returns a class that inherits from Actions and implements the environment-specific methods.
- Return type
type[gfn.actions.Actions]
- step(states, actions)
Function that takes a batch of states and actions and returns a batch of next states and a boolean tensor indicating sink states in the new batch.
- Parameters
states (gfn.states.DiscreteStates) –
actions (gfn.actions.Actions) –
- Return type
- class gfn.env.Env(s0, sf=None, device_str=None, preprocessor=None)
Bases: abc.ABC
Base class for all environments. Environments require that individual states be represented as a unique tensor of arbitrary shape.
- Parameters
s0 (torchtyping.TensorType[state_shape, torch.float]) –
sf (Optional[torchtyping.TensorType[state_shape, torch.float]]) –
device_str (Optional[str]) –
preprocessor (Optional[gfn.preprocessors.Preprocessor]) –
- property log_partition: float
Returns the logarithm of the partition function.
- Return type
float
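As a sketch of the quantity involved, assume the partition function is the sum of rewards over all terminating states (the usual GFlowNet convention; this page does not spell it out). The rewards below are made-up values for three hypothetical terminating states.

```python
import math

# Hypothetical rewards of three terminating states (illustrative values only).
rewards = [1.0, 2.0, 5.0]

# Partition function: total reward mass over terminating states.
partition = sum(rewards)
log_partition = math.log(partition)

# Normalizing rewards by the partition function yields a probability mass function.
pmf = [r / partition for r in rewards]
```

Under the same assumption, a target distribution such as the one described for true_dist_pmf would be obtained by this kind of normalization.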
- backward_step(states, actions)
Function that takes a batch of states and actions and returns a batch of next states and a boolean tensor indicating initial states in the new batch.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
- Return type
- abstract is_action_valid(states, actions, backward=False)
Returns True if the actions are valid in the given states.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
backward (bool) –
- Return type
bool
- abstract log_reward(final_states)
Either this or reward needs to be implemented.
- Parameters
final_states (gfn.states.States) –
- Return type
torchtyping.TensorType[batch_shape, torch.float]
- abstract make_Actions_class()
Returns a class that inherits from Actions and implements the environment-specific methods.
- Return type
type[gfn.actions.Actions]
- abstract make_States_class()
Returns a class that inherits from States and implements the environment-specific methods.
- Return type
type[gfn.states.States]
- abstract maskless_backward_step(states, actions)
Function that takes a batch of states and actions and returns a batch of previous states. Does not need to check whether the actions are valid or the states are sink states.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
- Return type
torchtyping.TensorType[batch_shape, state_shape, torch.float]
- abstract maskless_step(states, actions)
Function that takes a batch of states and actions and returns a batch of next states. Does not need to check whether the actions are valid or the states are sink states.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
- Return type
torchtyping.TensorType[batch_shape, state_shape, torch.float]
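The relationship between maskless_step and maskless_backward_step can be sketched in plain Python. This assumes a hypothetical environment in which action `i` increments coordinate `i` of the state, and it uses nested lists in place of tensors. The functions only borrow the documented names and are not the library's code.

```python
def maskless_step(states, actions):
    """Apply each action to its state: action i increments coordinate i."""
    next_states = []
    for state, action in zip(states, actions):
        new_state = list(state)  # copy so the input batch is left untouched
        new_state[action] += 1
        next_states.append(new_state)
    return next_states


def maskless_backward_step(states, actions):
    """Undo each action: action i decrements coordinate i."""
    previous_states = []
    for state, action in zip(states, actions):
        new_state = list(state)
        new_state[action] -= 1
        previous_states.append(new_state)
    return previous_states


batch = [[0, 0], [1, 2]]
actions = [0, 1]
forward = maskless_step(batch, actions)
restored = maskless_backward_step(forward, actions)
```

Neither function checks validity or sink states here, matching the note that those checks are not required of the maskless methods.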
- reset(batch_shape=None, random=False, sink=False, seed=None)
Instantiates a batch of initial states. random and sink cannot both be True. When random is True and seed is not None, the environment randomization is seeded with the given value for reproducibility.
- Parameters
batch_shape (Optional[Union[int, Tuple[int]]]) –
random (bool) –
sink (bool) –
seed (int) –
- Return type
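The interplay of the random, sink, and seed arguments can be sketched as follows. This is a plain-Python illustration under stated assumptions, not the library's reset: `S0`, `SF`, `H`, and the parameter name `random_states` are hypothetical, and the sketch assumes that sink=True yields a batch of sink states.

```python
import random

S0 = (0, 0)    # hypothetical initial state
SF = (-1, -1)  # hypothetical sink state
H = 4          # hypothetical grid side length for random states


def reset(batch_size=1, random_states=False, sink=False, seed=None):
    """Return a batch of initial, random, or sink states (toy version)."""
    assert not (random_states and sink), "random and sink cannot both be True"
    if sink:
        return [SF] * batch_size
    if random_states:
        rng = random.Random(seed)  # a fixed seed gives reproducible draws
        return [(rng.randrange(H), rng.randrange(H)) for _ in range(batch_size)]
    return [S0] * batch_size


initial = reset(3)
sinks = reset(2, sink=True)
first_draw = reset(5, random_states=True, seed=0)
second_draw = reset(5, random_states=True, seed=0)
```

Using a local `random.Random(seed)` rather than the global generator keeps the seeding effect confined to this call.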
- reward(final_states)
Either this or log_reward needs to be implemented.
- Parameters
final_states (gfn.states.States) –
- Return type
torchtyping.TensorType[batch_shape, torch.float]
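One way to satisfy the "either reward or log_reward" requirement is to let each method default to a transformation of the other, so that a subclass overrides only one. The sketch below uses a hypothetical `RewardMixin` in plain Python. Whether gfn.env.Env uses exactly this pattern is not stated on this page.

```python
import math


class RewardMixin:
    """Hypothetical helper, not part of gfn: each method defaults to the other.

    A subclass must override at least one of the two methods, otherwise the
    calls below would recurse forever.
    """

    def reward(self, final_states):
        return [math.exp(v) for v in self.log_reward(final_states)]

    def log_reward(self, final_states):
        return [math.log(v) for v in self.reward(final_states)]


class ToyEnv(RewardMixin):
    def reward(self, final_states):
        # Illustrative reward: strictly positive so the logarithm is defined.
        return [1.0 + s for s in final_states]


env = ToyEnv()
log_rewards = env.log_reward([0, 1])
```

Here only `reward` is implemented, and `log_reward` is derived from it by taking the logarithm elementwise.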
- step(states, actions)
Function that takes a batch of states and actions and returns a batch of next states and a boolean tensor indicating sink states in the new batch.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
- Return type
- validate_actions(states, actions, backward=False)
First, asserts that states and actions have the same batch_shape. Then, uses is_action_valid. Returns a boolean indicating whether states/actions pairs are valid.
- Parameters
states (gfn.states.States) –
actions (gfn.actions.Actions) –
backward (bool) –
- Return type
bool
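The two-stage check described for validate_actions (matching batch shapes first, then delegating to is_action_valid) can be sketched in plain Python. The chain environment, `MAX_POS`, and the list-based batches are assumptions made for illustration. Only the function names come from this page.

```python
MAX_POS = 3  # hypothetical largest position on a one-dimensional chain


def is_action_valid(states, actions):
    """Toy rule: every move must keep its position inside 0..MAX_POS."""
    return all(0 <= s + a <= MAX_POS for s, a in zip(states, actions))


def validate_actions(states, actions):
    """Check that batch sizes match, then delegate to is_action_valid."""
    assert len(states) == len(actions), "states and actions must share a batch shape"
    return is_action_valid(states, actions)


ok = validate_actions([0, 2], [1, 1])   # both moves stay on the chain
bad = validate_actions([0, 3], [1, 1])  # the second move would leave the chain
```

A single boolean is returned for the whole batch, consistent with the `bool` return type listed above.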
- gfn.env.NonValidActionsError