gfn.samplers
Module Contents
Classes
`Sampler` is a container for a PolicyEstimator.
- class gfn.samplers.Sampler(estimator, **probability_distribution_kwargs)
`Sampler` is a container for a PolicyEstimator.
Can be used to sample individual actions, sample trajectories from \(s_0\), or complete a batch of partially completed trajectories from a given batch of states.
- Parameters
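estimator (gfn.modules.GFNModule) – The PolicyEstimator to wrap.
probability_distribution_kwargs (Optional[dict]) – Keyword arguments passed to the estimator's to_probability_distribution method.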
- estimator
The PolicyEstimator passed to the constructor.
- probability_distribution_kwargs
Keyword arguments passed to the to_probability_distribution method of the estimator. For example, for DiscretePolicyEstimators, the kwargs can include temperature, epsilon, and sf_bias.
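To make the role of these keyword arguments concrete, here is a minimal sketch of how temperature, epsilon, and sf_bias could reshape a discrete action distribution. It is not the library's code: the function name `to_probabilities` is hypothetical, and the sketch assumes that the last logit belongs to the exit action, that sf_bias is subtracted from that logit, that temperature divides the logits before the softmax, and that epsilon mixes the result with a uniform distribution.

```python
import math


def to_probabilities(logits, temperature=1.0, epsilon=0.0, sf_bias=0.0):
    """Turn raw logits into action probabilities (illustrative sketch only)."""
    adjusted = list(logits)
    # Assumption: the last entry is the exit action; a positive bias
    # makes exiting less likely.
    adjusted[-1] -= sf_bias
    # Temperature > 1 flattens the distribution, < 1 sharpens it.
    scaled = [x / temperature for x in adjusted]
    # Numerically stable softmax.
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Epsilon-mix with a uniform distribution for extra exploration.
    n = len(probs)
    return [(1.0 - epsilon) * p + epsilon / n for p in probs]
```

With the defaults the function reduces to a plain softmax; setting epsilon to 1 yields a uniform distribution regardless of the logits.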
- sample_actions(env, states)
Samples actions from the given states.
- Parameters
env (gfn.env.Env) – The environment to sample actions from.
states (States) – A batch of states.
- Returns
A tuple containing:
- An Actions object containing the sampled actions.
- A tensor of shape (*batch_shape,) containing the log probabilities of the sampled actions under the probability distribution of the given states.
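The shape of this return value can be illustrated without the library. The sketch below uses plain Python lists in place of States, Actions, and tensors; `sample_actions_sketch` and `batch_probs` are hypothetical names, and each row of `batch_probs` is assumed to already hold the action probabilities for one state.

```python
import math
import random


def sample_actions_sketch(batch_probs, rng):
    """Sample one action index per state; return (actions, log_probs)."""
    actions, log_probs = [], []
    for probs in batch_probs:
        # Draw a single action index according to this state's probabilities.
        action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        actions.append(action)
        # Log probability of the action that was actually sampled.
        log_probs.append(math.log(probs[action]))
    return actions, log_probs


rng = random.Random(0)
batch_probs = [[0.7, 0.2, 0.1], [0.0, 1.0, 0.0]]
actions, log_probs = sample_actions_sketch(batch_probs, rng)
```

Both returned lists have one entry per state in the batch, mirroring the (*batch_shape,) shape described above.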
- sample_trajectories(env, states=None, n_trajectories=None)
Sample trajectories sequentially.
- Parameters
env (gfn.env.Env) – The environment to sample trajectories from.
states (Optional[gfn.states.States]) – If given, trajectories start from these states. Otherwise, trajectories are sampled from \(s_0\) and n_trajectories must be provided.
n_trajectories (Optional[int]) – If given, a batch of n_trajectories will be sampled, all starting from the environment’s \(s_0\).
- Returns
A Trajectories object representing the batch of sampled trajectories.
- Raises
AssertionError – When both states and n_trajectories are specified.
AssertionError – When states are not linear.
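The sequential sampling loop and the argument contract can be sketched on a toy problem. Everything here is hypothetical and independent of the library: states are integers, \(s_0\) is 0, the only actions are "step" (add 1) and "exit", and exit is forced at `MAX_STATE`. The assertion below also rejects the case where neither argument is given, which is an assumption based on the parameter descriptions above.

```python
import random

S0 = 0          # hypothetical initial state
MAX_STATE = 3   # at this state only the exit action is allowed
EXIT = "exit"   # marker appended when a trajectory terminates


def policy_probs(state):
    """Hypothetical fixed policy: probabilities for (step, exit)."""
    return (0.0, 1.0) if state >= MAX_STATE else (0.5, 0.5)


def sample_trajectories_sketch(states=None, n_trajectories=None, rng=None):
    rng = rng or random.Random()
    # Exactly one of `states` and `n_trajectories` must be provided.
    assert (states is None) != (n_trajectories is None), (
        "provide either states or n_trajectories, not both or neither"
    )
    if states is None:
        states = [S0] * n_trajectories
    trajectories = [[s] for s in states]
    active = list(range(len(trajectories)))
    # Advance every unfinished trajectory by one step per iteration.
    while active:
        still_active = []
        for i in active:
            current = trajectories[i][-1]
            p_step, _ = policy_probs(current)
            if rng.random() < p_step:
                trajectories[i].append(current + 1)
                still_active.append(i)
            else:
                trajectories[i].append(EXIT)
        active = still_active
    return trajectories
```

Calling `sample_trajectories_sketch(n_trajectories=4)` starts four trajectories at `S0`, while `sample_trajectories_sketch(states=[2, 3])` completes trajectories from the given states.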