gfn.samplers

Module Contents

Classes

Sampler

Sampler is a container for a PolicyEstimator.

class gfn.samplers.Sampler(estimator, **probability_distribution_kwargs)

Sampler is a container for a PolicyEstimator.

Can be used to sample individual actions, sample trajectories from \(s_0\), or complete a batch of partially completed trajectories from a given batch of states.

Parameters
estimator

The PolicyEstimator wrapped by the sampler.

probability_distribution_kwargs

Keyword arguments passed to the estimator's to_probability_distribution method. For example, for DiscretePolicyEstimators, the kwargs can contain temperature, epsilon, and sf_bias.
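The effect of such keyword arguments can be illustrated with a small standalone sketch. It does not call the library: the function name to_probabilities is hypothetical, and the way temperature, epsilon, and sf_bias are applied here (bias subtracted from the last, exit-action logit; temperature dividing the logits; epsilon mixing with a uniform distribution) is an assumption about a typical discrete policy, not a statement of what to_probability_distribution does.

```python
import math


def to_probabilities(logits, temperature=1.0, epsilon=0.0, sf_bias=0.0):
    """Hypothetical stand-in for turning one state's logits into action probabilities."""
    biased = list(logits)
    # Assumption: the last entry is the exit action, and sf_bias lowers its logit.
    biased[-1] -= sf_bias
    # Assumption: temperature divides the logits before the softmax.
    scaled = [x / temperature for x in biased]
    # Numerically stable softmax.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Assumption: epsilon mixes the policy with a uniform distribution.
    n = len(probs)
    return [(1.0 - epsilon) * p + epsilon / n for p in probs]


if __name__ == "__main__":
    logits = [1.0, 2.0, 3.0]
    print(to_probabilities(logits))
    print(to_probabilities(logits, temperature=10.0))
    print(to_probabilities(logits, epsilon=0.5))
    print(to_probabilities(logits, sf_bias=2.0))
```

Under these assumptions, a higher temperature or a larger epsilon flattens the distribution (more exploration), while a positive sf_bias makes the exit action less likely.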

sample_actions(env, states)

Samples actions from the given states.

Parameters
  • env (gfn.env.Env) – The environment to sample actions from.

  • states (States) – A batch of states.

Returns

A tuple containing:

  • An Actions object containing the sampled actions.

  • A tensor of shape (*batch_shape,) containing the log probabilities of the sampled actions under the probability distribution of the given states.
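The shape of this return value (one sampled action and one log probability per state in the batch) can be sketched without the library. The helper names below are hypothetical, plain Python lists stand in for the Actions object and the tensor, and the per-state probabilities are assumed to be given.

```python
import math
import random


def sample_action_with_log_prob(probs, rng):
    """Draw one action index from `probs` and return it with its log probability."""
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return action, math.log(probs[action])


def sample_actions_batch(batch_probs, seed=0):
    """Hypothetical stand-in: one action and one log probability per state."""
    rng = random.Random(seed)
    pairs = [sample_action_with_log_prob(probs, rng) for probs in batch_probs]
    actions = [action for action, _ in pairs]
    log_probs = [log_prob for _, log_prob in pairs]
    return actions, log_probs


if __name__ == "__main__":
    batch = [[0.2, 0.5, 0.3], [0.0, 1.0, 0.0]]
    actions, log_probs = sample_actions_batch(batch)
    print(actions, log_probs)
```

Both returned lists have one entry per input state, mirroring the (*batch_shape,) shape described above.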

sample_trajectories(env, states=None, n_trajectories=None)

Sample trajectories sequentially.

Parameters
  • env (gfn.env.Env) – The environment to sample trajectories from.

  • states (Optional[gfn.states.States]) – If given, trajectories start from these states. Otherwise, trajectories are sampled from \(s_0\) and n_trajectories must be provided.

  • n_trajectories (Optional[int]) – If given, a batch of n_trajectories trajectories is sampled, all starting from the environment’s \(s_0\).

Returns

A Trajectories object representing the batch of sampled trajectories.

Return type

gfn.containers.Trajectories

Raises
  • AssertionError – When both states and n_trajectories are specified.

  • AssertionError – When states are not linear.
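The argument contract and the sequential rollout can be sketched on a toy problem. Everything here is a stand-in rather than the library's implementation: sample_toy_trajectories is a hypothetical function, states are plain integers with \(s_0 = 0\), each step either increments the state or exits, and raising AssertionError when neither argument is given is a choice made for this sketch only.

```python
import random


def sample_toy_trajectories(states=None, n_trajectories=None, max_state=3, seed=0):
    """Roll out trajectories on a toy integer chain, one step at a time."""
    # Documented contract: states and n_trajectories must not both be given.
    assert states is None or n_trajectories is None, "pass states or n_trajectories, not both"
    # Sketch-only choice: also assert when neither is given.
    assert states is not None or n_trajectories is not None, "n_trajectories is required without states"

    rng = random.Random(seed)
    starts = [0] * n_trajectories if states is None else list(states)
    trajectories = [[start] for start in starts]
    active = list(range(len(starts)))

    # Step all unfinished trajectories until every one of them has exited.
    while active:
        still_active = []
        for i in active:
            current = trajectories[i][-1]
            # Exit is forced at max_state; otherwise it is a fair coin flip.
            if current >= max_state or rng.random() < 0.5:
                continue
            trajectories[i].append(current + 1)
            still_active.append(i)
        active = still_active
    return trajectories


if __name__ == "__main__":
    print(sample_toy_trajectories(n_trajectories=4))
    print(sample_toy_trajectories(states=[1, 2]))
```

Passing n_trajectories starts every rollout from the toy \(s_0\), while passing states completes rollouts from the given starting points; supplying both trips the first assertion.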
