gfn.losses.flow_matching

Module Contents

Classes

FMParametrization

\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\).

FlowMatching

Attributes

LossTensor

ScoresTensor

class gfn.losses.flow_matching.FMParametrization

Bases: gfn.losses.base.Parametrization

\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). This is equivalent to the set of functions from the internal nodes (i.e. excluding \(s_f\)) to \(\mathbb{R}^{\texttt{n\_actions}}\), with the exit action excluded. Positivity does not need to be enforced when log-flows are parametrized.
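To make the equivalence concrete, here is a minimal sketch. It assumes a hypothetical toy environment with three internal states and `n_actions = 3`, where the last action is the exit action and its entry is simply omitted. Each state maps to unconstrained reals that are read as log-flows. Exponentiating them always yields strictly positive edge flows, which is why no positivity constraint is needed. The state names, the dictionary layout, and `edge_flow` are illustrative only and are not part of `gfn`.

```python
import math

# Hypothetical toy setup: n_actions = 3, the last action (index 2) is the exit action.
N_ACTIONS = 3
EXIT_ACTION = N_ACTIONS - 1

# One vector of unconstrained reals per internal state (s_f excluded), one entry per
# non-exit action. They are interpreted as log-flows, so any real value is allowed.
log_edge_flows = {
    "s0": [0.3, -1.2],
    "s1": [-5.0, 2.1],
    "s2": [0.0, 0.0],
}


def edge_flow(state: str, action: int) -> float:
    """Return the edge flow F(state -> child reached via action) for a non-exit action."""
    if action == EXIT_ACTION:
        raise ValueError("the exit action is not covered by this parametrization")
    # exp maps R onto R^+, so the edge flow is positive whatever the raw output was.
    return math.exp(log_edge_flows[state][action])
```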

logF : gfn.estimators.LogEdgeFlowEstimator
Pi(env, n_samples=1000, **actions_sampler_kwargs)

Return type

gfn.distributions.TrajectoryDistribution
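In the usual GFlowNet construction, edge flows induce the forward policy \(P_F(s' \mid s) = F(s \to s') / \sum_{s''} F(s \to s'')\). This is a masked softmax over log-edge-flows. The sketch below assumes that this is how `Pi` turns `logF` into a policy. It does not reproduce trajectory sampling or the `TrajectoryDistribution` object. `forward_policy` is a hypothetical helper, and the log-flow of an allowed exit action (in the standard formulation, \(\log R(s)\)) is assumed to be supplied by the caller.

```python
import math
from typing import List, Sequence


def forward_policy(log_flows: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """P_F(a | s) = F(s -> s_a) / sum over allowed a' of F(s -> s_a').

    Computed as a masked softmax over log-edge-flows. Disallowed actions get
    probability 0.
    """
    if len(log_flows) != len(mask):
        raise ValueError("log_flows and mask must have the same length")
    allowed = [lf for lf, ok in zip(log_flows, mask) if ok]
    if not allowed:
        raise ValueError("no allowed action in this state")
    m = max(allowed)  # subtract the max before exponentiating, for numerical stability
    weights = [math.exp(lf - m) if ok else 0.0 for lf, ok in zip(log_flows, mask)]
    total = sum(weights)
    return [w / total for w in weights]
```

For example, `forward_policy([0.0, math.log(3), 0.0], [True, True, False])` returns approximately `[0.25, 0.75, 0.0]`. The masked third action gets zero probability, and the other two are proportional to flows 1 and 3.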

class gfn.losses.flow_matching.FlowMatching(parametrization, alpha=1.0)

Bases: gfn.losses.base.StateDecomposableLoss

Parameters

parametrization (FMParametrization) –

__call__(states_tuple)
Parameters

states_tuple (Tuple[gfn.containers.states.States, gfn.containers.states.States]) –

Return type

LossTensor
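`__call__` receives two batches of states. A plausible reading, stated here as an assumption rather than documented behaviour, works as follows:

- the first batch holds intermediary (non-initial) states, scored by `flow_matching_loss`;
- the second batch holds terminating states, scored by `reward_matching_loss`;
- the constructor argument `alpha` weights the second term.

The generic sketch below shows only that combination. `combined_fm_loss` and its callable arguments are hypothetical.

```python
from typing import Callable, Sequence, Tuple, TypeVar

S = TypeVar("S")


def combined_fm_loss(
    states_tuple: Tuple[Sequence[S], Sequence[S]],
    flow_matching_loss: Callable[[Sequence[S]], float],
    reward_matching_loss: Callable[[Sequence[S]], float],
    alpha: float = 1.0,
) -> float:
    """Assumed total loss: FM term on intermediary states + alpha * RM term on terminating states."""
    # Assumed ordering of the tuple: (intermediary states, terminating states).
    intermediary_states, terminating_states = states_tuple
    return flow_matching_loss(intermediary_states) + alpha * reward_matching_loss(
        terminating_states
    )
```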

flow_matching_loss(states)

Compute the flow-matching (FM) loss for the given states, defined in terms of the log-sum of incoming flows minus the log-sum of outgoing flows. The states must not include \(s_0\), and the batch shape should be (n_states,).

Currently, only discrete environments are supported.
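The quantity described above can be illustrated on a small hypothetical DAG: \(s_0 \to s_1\), \(s_0 \to s_2\), \(s_1 \to s_3\), \(s_2 \to s_3\), with \(s_3\) terminating. The sketch makes two assumptions. First, as in the standard flow-matching formulation, the exit edge of a terminating state \(s\) carries flow \(R(s)\). Second, per-state residuals are reduced with a mean of squares; the docstring defines only the difference, so this reduction is an assumption. Working in log space with a stable log-sum-exp avoids overflow. None of these helpers are part of `gfn`.

```python
import math
from typing import Dict, List, Tuple


def logsumexp(xs: List[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs)); -inf for an empty list."""
    if not xs:
        return -math.inf
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def fm_residuals(
    log_edge_flows: Dict[Tuple[str, str], float],
    log_rewards: Dict[str, float],
    states: List[str],
) -> List[float]:
    """Per-state log-sum of incoming flows minus log-sum of outgoing flows.

    `states` must not contain s0, which has no incoming edges. The exit edge
    s -> s_f of a terminating state is assumed to carry log-flow log R(s).
    """
    residuals = []
    for s in states:
        incoming = [lf for (_, child), lf in log_edge_flows.items() if child == s]
        outgoing = [lf for (parent, _), lf in log_edge_flows.items() if parent == s]
        if s in log_rewards:
            outgoing.append(log_rewards[s])  # exit action s -> s_f
        residuals.append(logsumexp(incoming) - logsumexp(outgoing))
    return residuals


def fm_loss(residuals: List[float]) -> float:
    """Assumed reduction: mean squared residual over the batch."""
    return sum(r * r for r in residuals) / len(residuals)


# Hypothetical toy DAG with consistent flows: F(s0->s1)=2, F(s0->s2)=1,
# F(s1->s3)=2, F(s2->s3)=1, so 3 units of flow reach s3.
EXAMPLE_FLOWS = {
    ("s0", "s1"): math.log(2.0),
    ("s0", "s2"): math.log(1.0),
    ("s1", "s3"): math.log(2.0),
    ("s2", "s3"): math.log(1.0),
}
```

With `EXAMPLE_FLOWS` and \(R(s_3) = 3\), every residual is numerically zero. Raising \(R(s_3)\) to 6 gives a residual of \(\log 3 - \log 6 = -\log 2\) at \(s_3\).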

Parameters

states (gfn.containers.states.States) –

Return type

ScoresTensor

reward_matching_loss(terminating_states)
Parameters

terminating_states (gfn.containers.states.States) –

Return type

LossTensor
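No description is given for `reward_matching_loss`. The sketch below assumes a common formulation. It compares the log-flow of each terminating state's exit edge, \(\log F(x \to s_f)\), with that state's log-reward \(\log R(x)\), and averages the squared gaps. The function is a hypothetical stand-alone illustration, not the `gfn` implementation.

```python
from typing import Sequence


def reward_matching_loss(
    exit_log_flows: Sequence[float], log_rewards: Sequence[float]
) -> float:
    """Mean squared gap between log F(x -> s_f) and log R(x) over terminating states x."""
    if len(exit_log_flows) != len(log_rewards):
        raise ValueError("expected one exit log-flow per terminating state")
    if not exit_log_flows:
        raise ValueError("need at least one terminating state")
    gaps = [lf - lr for lf, lr in zip(exit_log_flows, log_rewards)]
    return sum(g * g for g in gaps) / len(gaps)
```

For instance, exit log-flows `[1.0, 2.0]` against log-rewards `[1.0, 0.0]` give gaps of 0 and 2, hence a loss of 2.0.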

gfn.losses.flow_matching.LossTensor
gfn.losses.flow_matching.ScoresTensor