gfn.samplers


Package Contents

Classes

ActionsSampler – Base class for action sampling methods.

BackwardDiscreteActionsSampler – For sampling backward actions in discrete environments.

DiscreteActionsSampler – For sampling actions in discrete environments.

TrajectoriesSampler

class gfn.samplers.ActionsSampler

Bases: abc.ABC

Base class for action sampling methods.

abstract sample(states)
Parameters

states (States) – A batch of states.

Returns

A tuple of two tensors: the log probabilities of the sampled actions, then the sampled actions themselves.

Return type

Tuple[Tensor[batch_size], Tensor[batch_size]]
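The contract of sample can be illustrated with a minimal sketch. It assumes plain Python lists in place of the library's States and tensor types, and ToyActionsSampler and UniformSampler are hypothetical names, not part of gfn.samplers. What it shows is the return order: log probabilities first, then the sampled actions.

```python
import abc
import math
import random
from typing import List, Sequence, Tuple


class ToyActionsSampler(abc.ABC):
    """Hypothetical stand-in for ActionsSampler: lists replace States and tensors."""

    @abc.abstractmethod
    def sample(self, states: Sequence[object]) -> Tuple[List[float], List[int]]:
        """Return (log_probs, actions), one entry per state in the batch."""


class UniformSampler(ToyActionsSampler):
    """Hypothetical subclass: picks one of n_actions uniformly for every state."""

    def __init__(self, n_actions: int, seed: int = 0) -> None:
        self.n_actions = n_actions
        self.rng = random.Random(seed)

    def sample(self, states: Sequence[object]) -> Tuple[List[float], List[int]]:
        actions = [self.rng.randrange(self.n_actions) for _ in states]
        # Every action has probability 1 / n_actions, hence the same log-prob.
        log_probs = [-math.log(self.n_actions)] * len(actions)
        return log_probs, actions


log_probs, actions = UniformSampler(n_actions=4).sample(["s0", "s1", "s2"])
```

Because sample is abstract, instantiating the base class directly raises a TypeError; only concrete subclasses can be used.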

class gfn.samplers.BackwardDiscreteActionsSampler(estimator, temperature=1.0, epsilon=0.0)

Bases: DiscreteActionsSampler, BackwardActionsSampler

For sampling backward actions in discrete environments.

Parameters

estimator –

temperature –

epsilon –

get_logits(states)

Transforms the raw logits by masking illegal actions.

Raises

ValueError – if one of the resulting logits is NaN.

Returns

A 2D tensor of shape (batch_size, n_actions) containing the transformed logits.

Return type

Tensor2D

Parameters

states (gfn.containers.states.States) – A batch of states.
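A minimal sketch of the masking step, assuming the logits arrive as a nested list of shape (batch_size, n_actions) and that a boolean mask marks legal actions with True. mask_logits is a hypothetical helper, not the library's implementation, and which mask applies (forward or backward) is left abstract. Setting illegal logits to -inf is one common way to give them zero probability after a softmax.

```python
import math
from typing import List

NEG_INF = float("-inf")


def mask_logits(raw_logits: List[List[float]], masks: List[List[bool]]) -> List[List[float]]:
    """Hypothetical helper: replace the logit of every illegal action with -inf.

    masks[i][j] is True when action j is legal in state i (an assumption here).
    Raises ValueError if a resulting logit is NaN, mirroring the Raises entry above.
    """
    masked = []
    for row, mask_row in zip(raw_logits, masks):
        row_out = [x if legal else NEG_INF for x, legal in zip(row, mask_row)]
        if any(math.isnan(x) for x in row_out):
            raise ValueError("NaN found in the masked logits")
        masked.append(row_out)
    return masked


logits = mask_logits([[0.5, 1.0, -2.0]], [[True, False, True]])
# logits == [[0.5, -inf, -2.0]]
```

A NaN sitting in an illegal position is overwritten by -inf in this sketch, so only NaNs in legal positions trigger the error.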

get_probs(states)
Returns

A 2D tensor of shape (batch_size, n_actions) containing the probability of each action in each state in the batch.

Parameters

states (gfn.containers.states.States) – A batch of states.

Return type

Tensor2D
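Probabilities follow from the masked logits through a softmax. The sketch below assumes that the temperature parameter divides the logits before the softmax (a common convention, not confirmed by this page); probs_from_logits is a hypothetical name. Masked actions, whose logits are -inf, come out with probability exactly 0.

```python
import math
from typing import List


def probs_from_logits(logits: List[List[float]], temperature: float = 1.0) -> List[List[float]]:
    """Hypothetical sketch: temperature-scaled softmax over each row of masked logits.

    Assumes at least one legal (finite) logit per row; illegal actions carry -inf,
    and exp(-inf) == 0.0, so they receive zero probability.
    """
    probs = []
    for row in logits:
        scaled = [x / temperature for x in row]
        m = max(scaled)  # subtract the row maximum for numerical stability
        exps = [math.exp(x - m) for x in scaled]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


p = probs_from_logits([[0.0, float("-inf"), 0.0]])
# p == [[0.5, 0.0, 0.5]]
```

Under this convention, temperatures above 1 flatten the distribution and temperatures below 1 sharpen it.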

class gfn.samplers.DiscreteActionsSampler(estimator, temperature=1.0, sf_bias=0.0, epsilon=0.0)

Bases: ActionsSampler

For sampling actions in discrete environments.

Parameters

estimator –

temperature –

sf_bias –

epsilon –

get_logits(states)

Transforms the raw logits by masking illegal actions.

Raises

ValueError – if one of the resulting logits is NaN.

Returns

A 2D tensor of shape (batch_size, n_actions) containing the transformed logits.

Return type

Tensor2D

Parameters

states (gfn.containers.states.States) – A batch of states.

get_probs(states)
Returns

A 2D tensor of shape (batch_size, n_actions) containing the probability of each action in each state in the batch.

Parameters

states (gfn.containers.states.States) – A batch of states.

Return type

Tensor2D

get_raw_logits(states)

Returns the raw logits for a batch of states, before illegal actions are masked out and the exit action is biased. Should be used for discrete action spaces only.

Returns

A 2D tensor of shape (batch_size, n_actions) containing the logits for each action in each state in the batch.

Return type

Tensor2D

Parameters

states (gfn.containers.states.States) – A batch of states.
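The description of get_raw_logits implies an order of operations: raw logits first, then masking, then biasing of the exit action. The sketch below illustrates the biasing step for a single state. It assumes, without confirmation from this page, that the exit action is the last entry and that sf_bias is subtracted from its logit before the temperature softmax; forward_probs is a hypothetical name.

```python
import math
from typing import List


def forward_probs(logits: List[float], temperature: float = 1.0, sf_bias: float = 0.0) -> List[float]:
    """Hypothetical sketch for one state: bias the exit logit, then softmax.

    Assumptions: the exit action is the last entry of `logits`, and sf_bias is
    subtracted from it, so a positive sf_bias makes exiting less likely.
    """
    biased = list(logits)
    biased[-1] -= sf_bias
    scaled = [x / temperature for x in biased]
    m = max(scaled)  # numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


no_bias = forward_probs([1.0, 1.0, 1.0])
with_bias = forward_probs([1.0, 1.0, 1.0], sf_bias=2.0)
# The exit action (last entry) loses probability; the other actions share the freed mass.
```

If the exit action is masked (logit -inf), subtracting sf_bias leaves it at -inf, so the bias has no effect there.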

sample(states)
Parameters

states (States) – A batch of states.

Returns

A tuple of two tensors: the log probabilities of the sampled actions, then the sampled actions themselves.

Return type

Tuple[Tensor[batch_size], Tensor[batch_size]]
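A sketch of how epsilon might enter sampling, assuming an epsilon-greedy scheme: with probability epsilon the action is drawn uniformly among the legal actions, otherwise from the policy probabilities. sample_actions is hypothetical, and computing the log probability under the policy rather than under the epsilon-mixed distribution is an assumption of this sketch, not confirmed library behavior.

```python
import math
import random
from typing import List, Optional, Tuple


def sample_actions(
    probs: List[List[float]],
    masks: List[List[bool]],
    epsilon: float = 0.0,
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], List[int]]:
    """Hypothetical epsilon-greedy sampler for a batch of states.

    Returns (log_probs, actions). Log-probabilities are taken under `probs`
    (the policy itself), which is an assumption made for illustration.
    """
    if rng is None:
        rng = random.Random()
    log_probs, actions = [], []
    for row, mask_row in zip(probs, masks):
        if rng.random() < epsilon:
            # Exploration: uniform over the legal actions of this state.
            legal = [j for j, ok in enumerate(mask_row) if ok]
            action = rng.choice(legal)
        else:
            # Exploitation: draw from the policy's categorical distribution.
            action = rng.choices(range(len(row)), weights=row, k=1)[0]
        actions.append(action)
        p = row[action]
        # Guard against math.log(0.0), which would raise ValueError.
        log_probs.append(math.log(p) if p > 0 else float("-inf"))
    return log_probs, actions


probs = [[0.5, 0.0, 0.5], [0.0, 1.0, 0.0]]
masks = [[True, False, True], [False, True, False]]
log_probs, actions = sample_actions(probs, masks, epsilon=0.1, rng=random.Random(0))
```

With epsilon=0.0 this reduces to plain categorical sampling from the policy.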

class gfn.samplers.TrajectoriesSampler(env, actions_sampler)
Parameters

env –

actions_sampler –

sample(n_trajectories)
Parameters

n_trajectories (int) – Number of trajectories to sample.

Return type

gfn.containers.Trajectories

sample_trajectories(states=None, n_trajectories=None)
Parameters

states –

n_trajectories –

Return type

gfn.containers.Trajectories
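Conceptually, a trajectory sampler ties an environment to an actions sampler and keeps sampling actions until every trajectory has taken the exit action. The sketch below uses a hypothetical counter environment and a uniform policy in place of env and actions_sampler. It also assumes a convention for the two optional arguments of sample_trajectories (start from the given states, otherwise from n_trajectories copies of the initial state) and that sample(n_trajectories) delegates to it; neither is confirmed by this page.

```python
import math
import random
from dataclasses import dataclass, field
from typing import List, Optional

EXIT = 1  # hypothetical toy env: action 0 = "+1", action 1 = exit (last action)


@dataclass
class ToyTrajectory:
    states: List[int] = field(default_factory=list)
    actions: List[int] = field(default_factory=list)
    log_probs: List[float] = field(default_factory=list)


class ToyTrajectoriesSampler:
    """Hypothetical rollout loop, not the gfn implementation.

    The toy environment is a counter starting at s0 = 0 that can be incremented
    up to max_state; from max_state only the exit action is legal.
    """

    def __init__(self, max_state: int = 3, seed: int = 0) -> None:
        self.max_state = max_state
        self.rng = random.Random(seed)

    def legal_actions(self, state: int) -> List[int]:
        return [0, EXIT] if state < self.max_state else [EXIT]

    def sample_trajectories(
        self, states: Optional[List[int]] = None, n_trajectories: Optional[int] = None
    ) -> List[ToyTrajectory]:
        # Assumed convention: use the given starting states, or else
        # n_trajectories copies of the initial state s0 = 0.
        if states is None:
            if n_trajectories is None:
                raise ValueError("pass either states or n_trajectories")
            states = [0] * n_trajectories
        trajectories = [ToyTrajectory(states=[s]) for s in states]
        for traj in trajectories:
            state = traj.states[-1]
            while True:
                legal = self.legal_actions(state)
                action = self.rng.choice(legal)  # uniform policy over legal actions
                traj.actions.append(action)
                traj.log_probs.append(-math.log(len(legal)))
                if action == EXIT:
                    break
                state += 1
                traj.states.append(state)
        return trajectories

    def sample(self, n_trajectories: int) -> List[ToyTrajectory]:
        return self.sample_trajectories(n_trajectories=n_trajectories)


trajs = ToyTrajectoriesSampler(max_state=3).sample(5)
```

For readability the sketch rolls out one trajectory at a time; a batched sampler would instead advance all unfinished trajectories together at each step.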