gfn.utils

Module Contents

Functions

trajectories_to_training_samples(trajectories, loss_fn)

Converts a Trajectories container to a States, Transitions or Trajectories container, depending on the loss.

validate(env, parametrization[, n_validation_samples, ...])

Evaluates the current parametrization on the given environment.

gfn.utils.trajectories_to_training_samples(trajectories, loss_fn)

Converts a Trajectories container to a States, Transitions or Trajectories container, depending on the loss.

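Parameters
  • trajectories (Trajectories) – The trajectories to convert.

  • loss_fn – The loss function. Its type determines which container is returned.

Return type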

States | Transitions | Trajectories
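The "depending on the loss" dispatch can be pictured with the minimal sketch below. Everything in it is hypothetical: the Trajectories dataclass, the StateLoss, TransitionLoss and TrajectoryLoss marker classes and to_training_samples_sketch are illustrative stand-ins, not gfn's actual classes. The real function may also keep only a subset of states (for example, excluding initial states) rather than every visited state. The mapping follows the usual granularity of GFlowNet objectives: flow matching is evaluated per state, detailed balance per transition, and trajectory balance per trajectory.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union

# Hypothetical stand-ins: states are plain ints, transitions are (parent, child) pairs.
State = int
Transition = Tuple[State, State]


@dataclass
class Trajectories:
    # Each path is the sequence of states visited by one sampled trajectory.
    paths: List[List[State]]

    def to_states(self) -> List[State]:
        # Flatten every visited state (a simplification; see the lead-in).
        return [s for path in self.paths for s in path]

    def to_transitions(self) -> List[Transition]:
        # Consecutive (s, s') pairs along each path.
        return [(p[i], p[i + 1]) for p in self.paths for i in range(len(p) - 1)]


class StateLoss: ...       # hypothetical, e.g. flow matching (per state)
class TransitionLoss: ...  # hypothetical, e.g. detailed balance (per transition)
class TrajectoryLoss: ...  # hypothetical, e.g. trajectory balance (per trajectory)


def to_training_samples_sketch(
    trajectories: Trajectories, loss_fn: object
) -> Union[List[State], List[Transition], Trajectories]:
    # Pick the container granularity that the loss consumes.
    if isinstance(loss_fn, StateLoss):
        return trajectories.to_states()
    if isinstance(loss_fn, TransitionLoss):
        return trajectories.to_transitions()
    if isinstance(loss_fn, TrajectoryLoss):
        return trajectories
    raise ValueError(f"Unsupported loss type: {type(loss_fn).__name__}")
```

For example, to_training_samples_sketch(Trajectories(paths=[[0, 1, 2], [0, 3]]), TransitionLoss()) returns [(0, 1), (1, 2), (0, 3)]. Dispatching on the loss type lets a training loop sample trajectories once and stay agnostic of which objective it optimizes.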

gfn.utils.validate(env, parametrization, n_validation_samples=1000, visited_terminating_states=None)

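Evaluates the current parametrization on the given environment. This is intended for environments with a known target reward, so that the target distribution over terminating states is available. Validation is done by computing the L1 distance between the empirical distribution of terminating states sampled with the learned parametrization and the target distribution.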

Parameters
  • env (gfn.envs.Env) – The environment to evaluate the parametrization on.

  • parametrization (gfn.losses.Parametrization) – The parametrization to evaluate.

  • n_validation_samples (int) – The number of samples used to estimate the empirical pmf.

  • visited_terminating_states (Optional[gfn.containers.States]) – The terminating states visited during training. If given, the pmf is computed from the last n_validation_samples of these states. Otherwise, n_validation_samples new states are sampled for evaluation.

Returns

A dictionary containing the L1 validation metric. If the parametrization is a TBParametrization, i.e. it contains a learned LogZ, the absolute difference between the learned and the target LogZ is also included in the dictionary.

Return type

Dict[str, float]
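The validation metric can be illustrated on a toy environment whose terminating states can be enumerated. This is a minimal sketch under stated assumptions: validate_sketch, its arguments and the dictionary keys "l1_dist" and "logZ_diff" are hypothetical, states are plain hashable values, rewards are given as a dict over all terminating states, and the L1 distance is taken as a sum of absolute differences (the library may normalize it differently, for example by averaging over states).

```python
import math
from collections import Counter
from typing import Dict, Hashable, List, Optional


def validate_sketch(
    visited_terminating_states: List[Hashable],
    true_rewards: Dict[Hashable, float],
    n_validation_samples: int = 1000,
    learned_log_z: Optional[float] = None,
) -> Dict[str, float]:
    """Hypothetical stand-in for gfn.utils.validate on a tiny, enumerable env."""
    # Keep only the most recent n_validation_samples visited states.
    samples = visited_terminating_states[-n_validation_samples:]
    if not samples:
        raise ValueError("need at least one terminating state")
    counts = Counter(samples)

    # Target pmf: rewards normalized over every terminating state.
    true_z = sum(true_rewards.values())
    true_pmf = {s: r / true_z for s, r in true_rewards.items()}

    # Empirical pmf: visit frequencies of the same terminating states.
    empirical_pmf = {s: counts[s] / len(samples) for s in true_rewards}

    # L1 distance: sum of absolute pmf differences over all terminating states.
    l1 = sum(abs(empirical_pmf[s] - true_pmf[s]) for s in true_rewards)
    info = {"l1_dist": l1}

    # Only parametrizations with a learned log Z (e.g. TB) get this extra entry.
    if learned_log_z is not None:
        info["logZ_diff"] = abs(learned_log_z - math.log(true_z))
    return info
```

With rewards {"a": 1.0, "b": 3.0}, the visits ["a", "a", "b", "b"] and learned_log_z=math.log(4.0), the sketch returns {'l1_dist': 0.5, 'logZ_diff': 0.0}: the target pmf is (0.25, 0.75), the empirical pmf is (0.5, 0.5), and the learned log Z matches log(1 + 3).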