gfn.losses.flow_matching
Module Contents
Classes
- FMParametrization: \(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\).
- FlowMatching
Attributes
- LossTensor
- ScoresTensor
- class gfn.losses.flow_matching.FMParametrization
Bases: gfn.losses.base.Parametrization
\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). Equivalently, it is the set of functions from the internal nodes (i.e., all states except \(s_f\)) to \(\mathbb{R}^{n_{\text{actions}}}\), without the exit action. If log-flows are parametrized instead, no positivity constraint is needed.
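The equivalence above can be illustrated with a minimal pure-Python sketch. It assumes a discrete action space indexed from 0 and stores one real log-flow per action for each internal state in a plain dictionary, which stands in for `logF`. The names `per_state_log_flows` and `as_edge_flows` are hypothetical and are not part of gfn. Re-indexing the per-state vectors by `(state, action)` yields a function on edges, and exponentiating guarantees positive flows.

```python
import math
from typing import Dict, List, Tuple

State = Tuple[int, ...]

# Hypothetical per-state outputs: one real-valued log-flow per action,
# the role played by logF. Any real value is allowed here.
per_state_log_flows: Dict[State, List[float]] = {
    (0, 0): [0.0, -1.0, 2.0],
    (1, 0): [-0.5, 3.0, -2.0],
}


def as_edge_flows(table: Dict[State, List[float]]) -> Dict[Tuple[State, int], float]:
    """Re-index per-state log-flow vectors as a function on edges (state, action).

    Exponentiating makes every edge flow strictly positive, so the
    parametrized log-flows themselves need no positivity constraint.
    """
    return {
        (state, action): math.exp(log_flow)
        for state, log_flows in table.items()
        for action, log_flow in enumerate(log_flows)
    }


edge_flows = as_edge_flows(per_state_log_flows)
```

Negative log-flows such as `-2.0` still map to positive edge flows, which is why a log-flow parametrization can use an unconstrained output layer.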
- logF: gfn.estimators.LogEdgeFlowEstimator
- Pi(env, n_samples=1000, **actions_sampler_kwargs)
- Parameters
env (gfn.envs.Env) –
n_samples (int) –
- Return type
- class gfn.losses.flow_matching.FlowMatching(parametrization, alpha=1.0)
Bases: gfn.losses.base.StateDecomposableLoss
- Parameters
parametrization (FMParametrization) –
- __call__(states_tuple)
- Parameters
states_tuple (Tuple[gfn.containers.states.States, gfn.containers.states.States]) –
- Return type
LossTensor
- flow_matching_loss(states)
Compute the flow-matching (FM) term for the given states, defined as the log-sum of incoming flows minus the log-sum of outgoing flows. The states must not include \(s_0\), and their batch shape should be (n_states,).
Currently, only discrete environments are supported.
- Parameters
states (gfn.containers.states.States) –
- Return type
ScoresTensor
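The per-state definition above can be illustrated on a toy DAG. This is a minimal pure-Python sketch, not the library's tensor implementation. It assumes edge log-flows are stored in a plain dictionary keyed by hypothetical `(parent, child)` string pairs, with `"sf"` standing for \(s_f\), and it does not use gfn's `States` or estimator classes. Outgoing flows include the exit edge \(s \to s_f\).

```python
import math
from typing import Dict, List, Tuple

Edge = Tuple[str, str]


def logsumexp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    if not values:
        return -math.inf
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def flow_matching_scores(log_flows: Dict[Edge, float], states: List[str]) -> Dict[str, float]:
    """Per-state score: log-sum of incoming flows minus log-sum of outgoing flows.

    `states` must not contain the initial state "s0", which has no incoming edges.
    """
    scores = {}
    for s in states:
        incoming = [lf for (parent, child), lf in log_flows.items() if child == s]
        outgoing = [lf for (parent, child), lf in log_flows.items() if parent == s]
        scores[s] = logsumexp(incoming) - logsumexp(outgoing)
    return scores


# Hypothetical toy DAG: s0 -> {a, b} -> c, and every internal state can exit to sf.
log_flows: Dict[Edge, float] = {
    ("s0", "a"): math.log(2.0), ("s0", "b"): math.log(1.0),
    ("a", "c"): math.log(1.0), ("a", "sf"): math.log(1.0),
    ("b", "c"): math.log(0.5), ("b", "sf"): math.log(0.5),
    ("c", "sf"): math.log(1.5),
}
scores = flow_matching_scores(log_flows, ["a", "b", "c"])

# Squaring and averaging the scores is one common way to obtain a scalar loss.
loss = sum(v ** 2 for v in scores.values()) / len(scores)
```

Because inflow equals outflow at every internal state here, all scores are zero up to rounding. Raising the flow on `("c", "sf")` to 3.0 would make the score of `c` equal to \(\log(1.5/3) = -\log 2\).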
- reward_matching_loss(terminating_states)
- Parameters
terminating_states (gfn.containers.states.States) –
- Return type
LossTensor
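`reward_matching_loss` has no docstring. In the standard flow-matching formulation, the flow on the exit edge \(s \to s_f\) of a terminating state should equal its reward \(R(s)\). The minimal sketch below assumes the loss is the mean squared gap between \(\log F(s \to s_f)\) and \(\log R(s)\). It also assumes, based only on the `alpha=1.0` constructor argument, that `FlowMatching` weights this term by `alpha` when adding it to the FM term. Both functions are hypothetical plain-Python stand-ins, not the gfn API.

```python
import math
from typing import List


def reward_matching_loss(exit_log_flows: List[float], log_rewards: List[float]) -> float:
    """Mean squared gap between log F(s -> s_f) and log R(s) over terminating states."""
    if len(exit_log_flows) != len(log_rewards):
        raise ValueError("need exactly one log-reward per terminating state")
    if not exit_log_flows:
        raise ValueError("need at least one terminating state")
    gaps = [lf - lr for lf, lr in zip(exit_log_flows, log_rewards)]
    return sum(g * g for g in gaps) / len(gaps)


def total_loss(fm_loss: float, rm_loss: float, alpha: float = 1.0) -> float:
    """Hypothetical combination: FM term plus an alpha-weighted reward-matching term."""
    return fm_loss + alpha * rm_loss


# Exit flows that match the rewards exactly give a zero reward-matching loss.
rm = reward_matching_loss([math.log(1.0), math.log(0.5)], [0.0, math.log(0.5)])
```

Under these assumptions, a larger `alpha` puts more weight on matching terminal flows to rewards relative to conserving flow at intermediate states.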
- gfn.losses.flow_matching.LossTensor
- gfn.losses.flow_matching.ScoresTensor