gfn.samplers.actions_samplers

Module Contents

Classes

ActionsSampler

Base class for action sampling methods.

BackwardActionsSampler

Base class for backward action sampling methods.

BackwardDiscreteActionsSampler

For sampling backward actions in discrete environments.

DiscreteActionsSampler

For sampling actions in discrete environments.

Attributes

Tensor1D

Tensor2D

Tensor2D2

class gfn.samplers.actions_samplers.ActionsSampler

Bases: abc.ABC

Base class for action sampling methods.

abstract sample(states)
Parameters

states (States) – A batch of states.

Returns

A tuple of two tensors: the log probabilities of the sampled actions, followed by the sampled actions.

Return type

Tuple[Tensor[batch_size], Tensor[batch_size]]
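As an illustration of this interface, the following minimal sketch defines a stand-in abstract base class and a hypothetical `UniformSampler` subclass. It does not import gfn: plain Python lists stand in for the `States` batch and for the tensors, and the names `UniformSampler`, `n_actions`, and `seed` are assumptions made for the example only.

```python
import math
import random
from abc import ABC, abstractmethod
from typing import List, Sequence, Tuple


class ActionsSampler(ABC):
    """Stand-in for the documented base class (not the library's code)."""

    @abstractmethod
    def sample(self, states: Sequence) -> Tuple[List[float], List[int]]:
        """Return (log_probs, actions), one entry per state in the batch."""


class UniformSampler(ActionsSampler):
    """Hypothetical sampler that picks every action with equal probability."""

    def __init__(self, n_actions: int, seed: int = 0) -> None:
        self.n_actions = n_actions
        self.rng = random.Random(seed)

    def sample(self, states: Sequence) -> Tuple[List[float], List[int]]:
        # One action index per state, drawn uniformly at random.
        actions = [self.rng.randrange(self.n_actions) for _ in states]
        # Under a uniform policy every action has log-probability -log(n_actions).
        log_probs = [-math.log(self.n_actions)] * len(actions)
        return log_probs, actions


log_probs, actions = UniformSampler(n_actions=4).sample(["s0", "s1", "s2"])
```

Both returned lists have one entry per state, mirroring the `Tensor[batch_size]` shapes in the return type above.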

class gfn.samplers.actions_samplers.BackwardActionsSampler

Bases: ActionsSampler

Base class for backward action sampling methods.

class gfn.samplers.actions_samplers.BackwardDiscreteActionsSampler(estimator, temperature=1.0, epsilon=0.0)

Bases: DiscreteActionsSampler, BackwardActionsSampler

For sampling backward actions in discrete environments.

get_logits(states)

Transforms the raw logits by masking illegal actions.

Raises

ValueError – if one of the resulting logits is NaN.

Returns

A 2D tensor of shape (batch_size, n_actions) containing the transformed logits.

Return type

Tensor2D

Parameters

states (gfn.containers.states.States) – A batch of states.

get_probs(states)
Returns

The probabilities of each action in each state in the batch.

Parameters

states (gfn.containers.states.States) – A batch of states.

Return type

Tensor2D

class gfn.samplers.actions_samplers.DiscreteActionsSampler(estimator, temperature=1.0, sf_bias=0.0, epsilon=0.0)

Bases: ActionsSampler

For sampling actions in discrete environments.

get_logits(states)

Transforms the raw logits by masking illegal actions.

Raises

ValueError – if one of the resulting logits is NaN.

Returns

A 2D tensor of shape (batch_size, n_actions) containing the transformed logits.

Return type

Tensor2D

Parameters

states (gfn.containers.states.States) – A batch of states.
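The masking step described for `get_logits` can be sketched roughly as follows. This is not the library's implementation: nested Python lists stand in for 2D tensors, and the helper name `mask_logits` and its `legal` argument (a boolean mask of allowed actions) are hypothetical.

```python
import math
from typing import List


def mask_logits(
    raw_logits: List[List[float]], legal: List[List[bool]]
) -> List[List[float]]:
    """Set the logits of illegal actions to -inf; raise if any result is NaN."""
    masked = [
        [logit if ok else -math.inf for logit, ok in zip(row, mask_row)]
        for row, mask_row in zip(raw_logits, legal)
    ]
    # Mirror the documented behaviour: a NaN among the results is an error.
    if any(math.isnan(x) for row in masked for x in row):
        raise ValueError("NaN found in the transformed logits")
    return masked


masked = mask_logits(
    [[0.5, 1.0, -0.2], [2.0, 0.0, 1.5]],
    [[True, False, True], [True, True, False]],
)
```

Using negative infinity for illegal actions means a later softmax assigns them probability zero, provided each state has at least one legal action.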

get_probs(states)
Returns

The probabilities of each action in each state in the batch.

Parameters

states (gfn.containers.states.States) – A batch of states.

Return type

Tensor2D
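One plausible way to turn masked logits into probabilities is a softmax scaled by the `temperature` constructor argument. The sketch below assumes that the temperature divides the logits before the softmax, which this page does not state; the function name `softmax_with_temperature` is hypothetical, and a single row is handled as a plain list.

```python
import math
from typing import List


def softmax_with_temperature(
    logits: List[float], temperature: float = 1.0
) -> List[float]:
    """Softmax over one row of logits; assumes at least one finite logit."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the row maximum for numerical stability
    exps = [math.exp(x - top) for x in scaled]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


# The middle action is masked out (-inf), so it receives zero probability.
probs = softmax_with_temperature([0.0, -math.inf, 0.0])
```

Under this assumption, a temperature above 1 flattens the distribution and a temperature below 1 sharpens it.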

get_raw_logits(states)

Computes the raw logits, that is, the logits before illegal actions are masked out and before the exit action is biased. Intended for discrete action spaces only.

Returns

A 2D tensor of shape (batch_size, n_actions) containing the logits for each action in each state in the batch.

Return type

Tensor2D

Parameters

states (gfn.containers.states.States) – A batch of states.
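To make the distinction between raw and biased logits concrete, here is a small sketch of how an exit-action bias might be applied. It rests on two assumptions that this page does not confirm: that the exit action occupies the last column, and that `sf_bias` is subtracted from its logit. The helper name `bias_exit_action` is hypothetical.

```python
from typing import List


def bias_exit_action(
    raw_logits: List[List[float]], sf_bias: float = 0.0
) -> List[List[float]]:
    """Return a copy in which the last logit of each row is lowered by sf_bias."""
    # Assumption: the exit action is stored in the last column of each row.
    return [row[:-1] + [row[-1] - sf_bias] for row in raw_logits]


raw = [[0.1, 0.2, 0.3], [1.0, 1.0, 1.0]]
biased = bias_exit_action(raw, sf_bias=1.0)
```

The raw logits are left untouched, so they remain available for inspection or for a different transformation.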

sample(states)
Parameters

states (States) – A batch of states.

Returns

A tuple of two tensors: the log probabilities of the sampled actions, followed by the sampled actions.

Return type

Tuple[Tensor[batch_size], Tensor[batch_size]]
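A rough sketch of the sampling step is given below. It assumes that `epsilon` mixes the policy with a uniform distribution over legal actions (epsilon-greedy style exploration), which this page does not spell out. It also treats any action with zero probability as illegal, purely to keep the example short. The function name `sample_actions` and the `seed` argument are hypothetical, and lists stand in for tensors.

```python
import math
import random
from typing import List, Tuple


def sample_actions(
    probs: List[List[float]], epsilon: float = 0.0, seed: int = 0
) -> Tuple[List[float], List[int]]:
    """Draw one action per row and return (log_probs, actions)."""
    rng = random.Random(seed)
    log_probs: List[float] = []
    actions: List[int] = []
    for row in probs:
        # Simplifying assumption: zero probability marks an illegal action.
        legal = [p > 0 for p in row]
        n_legal = sum(legal)
        # Mix the policy with a uniform distribution over legal actions.
        mixed = [
            (1 - epsilon) * p + (epsilon / n_legal if ok else 0.0)
            for p, ok in zip(row, legal)
        ]
        action = rng.choices(range(len(row)), weights=mixed, k=1)[0]
        actions.append(action)
        # Log-probability under the distribution actually sampled from.
        log_probs.append(math.log(mixed[action]))
    return log_probs, actions


log_probs, actions = sample_actions(
    [[0.5, 0.0, 0.5], [0.0, 1.0, 0.0]], epsilon=0.1
)
```

The log probabilities are computed under the mixed distribution, so they match the distribution the actions were actually drawn from.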

gfn.samplers.actions_samplers.Tensor1D
gfn.samplers.actions_samplers.Tensor2D
gfn.samplers.actions_samplers.Tensor2D2