gfn.samplers
Submodules
Package Contents
Classes
ActionsSampler | Base class for action sampling methods.
BackwardDiscreteActionsSampler | For sampling backward actions in discrete environments.
DiscreteActionsSampler | For Discrete environments.
- class gfn.samplers.ActionsSampler
Bases: abc.ABC
Base class for action sampling methods.
- class gfn.samplers.BackwardDiscreteActionsSampler(estimator, temperature=1.0, epsilon=0.0)
Bases: DiscreteActionsSampler, BackwardActionsSampler
For sampling backward actions in discrete environments.
- Parameters
estimator (gfn.estimators.LogitPBEstimator) –
temperature (float) –
epsilon (float) –
- get_logits(states)
Transforms the raw logits by masking illegal actions.
- Raises
ValueError – if one of the resulting logits is NaN.
- Returns
A 2D tensor of shape (batch_size, n_actions) containing the transformed logits.
- Return type
Tensor2D
- Parameters
states (gfn.containers.states.States) –
- get_probs(states)
- Returns
The probabilities of each action in each state in the batch.
- Parameters
states (gfn.containers.states.States) –
- Return type
Tensor2D
- class gfn.samplers.DiscreteActionsSampler(estimator, temperature=1.0, sf_bias=0.0, epsilon=0.0)
Bases: ActionsSampler
For Discrete environments.
- Parameters
estimator (LogitPFEstimator | LogEdgeFlowEstimator) –
temperature (float) –
sf_bias (float) –
epsilon (float) –
- get_logits(states)
Transforms the raw logits by masking illegal actions.
- Raises
ValueError – if one of the resulting logits is NaN.
- Returns
A 2D tensor of shape (batch_size, n_actions) containing the transformed logits.
- Return type
Tensor2D
- Parameters
states (gfn.containers.states.States) –
- get_probs(states)
- Returns
The probabilities of each action in each state in the batch.
- Parameters
states (gfn.containers.states.States) –
- Return type
Tensor2D
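The relationship between get_logits and get_probs can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the helper names masked_logits and probs_from_logits are hypothetical, nested lists stand in for Tensor2D, and it assumes that temperature divides the logits before the softmax and that epsilon mixes in a uniform distribution over legal actions. Neither assumption is confirmed by this page.

```python
import math


def masked_logits(raw_logits, legal_mask):
    """Set the logits of illegal actions to -inf (hypothetical helper).

    raw_logits and legal_mask are nested lists of shape (batch_size, n_actions).
    Raises ValueError if any resulting logit is NaN.
    """
    out = []
    for row, mask in zip(raw_logits, legal_mask):
        new_row = [x if ok else -math.inf for x, ok in zip(row, mask)]
        if any(math.isnan(x) for x in new_row):
            raise ValueError("NaN found in logits")
        out.append(new_row)
    return out


def probs_from_logits(logits, legal_mask, temperature=1.0, epsilon=0.0):
    """Turn masked logits into per-state action probabilities.

    Assumption: softmax(logits / temperature), then a mixture with a uniform
    policy over legal actions weighted by epsilon. Each state is assumed to
    have at least one legal action.
    """
    all_probs = []
    for row, mask in zip(logits, legal_mask):
        scaled = [x / temperature for x in row]
        top = max(scaled)  # subtract the max for numerical stability
        exps = [math.exp(x - top) for x in scaled]  # exp(-inf) is 0.0
        total = sum(exps)
        n_legal = sum(mask)
        all_probs.append(
            [
                (1 - epsilon) * e / total + (epsilon / n_legal if ok else 0.0)
                for e, ok in zip(exps, mask)
            ]
        )
    return all_probs


raw = [[0.0, 0.0, 5.0]]
mask = [[True, True, False]]
probs = probs_from_logits(masked_logits(raw, mask), mask)
```

In this sketch the illegal third action receives probability zero regardless of its raw logit, and each row of the result sums to one.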
- get_raw_logits(states)
Returns the raw logits, before illegal actions are masked out and before the exit action is biased. Should be used for Discrete action spaces only.
- Returns
A 2D tensor of shape (batch_size, n_actions) containing the logits for each action in each state in the batch.
- Return type
Tensor2D
- Parameters
states (gfn.containers.states.States) –
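As a rough illustration of what biasing the exit action could mean, the sketch below applies a scalar bias to one column of the raw logits. It rests on two assumptions that this page does not confirm: that the exit action is the last action index, and that sf_bias is subtracted from its logit. The function name bias_exit_logit is hypothetical and nested lists stand in for Tensor2D.

```python
def bias_exit_logit(raw_logits, sf_bias=0.0):
    """Return a copy of raw_logits with sf_bias subtracted from the last column.

    Assumption: the last action index is the exit action. The input is left
    unmodified.
    """
    return [row[:-1] + [row[-1] - sf_bias] for row in raw_logits]


raw = [[1.0, 2.0, 3.0]]
biased = bias_exit_logit(raw, sf_bias=1.5)
```

Under these assumptions, a positive sf_bias makes the exit action less likely, which would tend to lengthen sampled trajectories.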
- class gfn.samplers.TrajectoriesSampler(env, actions_sampler)
- Parameters
env (gfn.envs.Env) –
actions_sampler (gfn.samplers.actions_samplers.ActionsSampler) –
- sample(n_trajectories)
- Parameters
n_trajectories (int) –
- Return type
- sample_trajectories(states=None, n_trajectories=None)
- Parameters
states (Optional[gfn.containers.States]) –
n_trajectories (Optional[int]) –
- Return type
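To make the role of a trajectories sampler concrete, here is a toy rollout loop in pure Python. It is a sketch under stated assumptions, not the library's code: the function toy_sample_trajectories is hypothetical, a uniform random policy stands in for the actions sampler, the last action index is assumed to be the exit action, and a length cap forces termination.

```python
import random


def toy_sample_trajectories(n_trajectories, n_actions=3, max_length=10, seed=0):
    """Roll out trajectories in a toy setting until each takes the exit action.

    Each trajectory is returned as the list of action indices it took.
    """
    rng = random.Random(seed)
    exit_action = n_actions - 1  # assumption: last index is the exit action
    trajectories = [[] for _ in range(n_trajectories)]
    active = list(range(n_trajectories))
    step = 0
    while active:
        step += 1
        still_active = []
        for i in active:
            # Force the exit action at the length cap so every rollout ends.
            if step == max_length:
                action = exit_action
            else:
                action = rng.randrange(n_actions)
            trajectories[i].append(action)
            if action != exit_action:
                still_active.append(i)
        active = still_active
    return trajectories


trajs = toy_sample_trajectories(4)
```

Going by the signatures documented above, the corresponding call with the real classes would look like TrajectoriesSampler(env, actions_sampler).sample(n_trajectories); how env and the estimator are constructed is not covered on this page.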