gfn.utils.common

Module Contents

Functions

get_terminating_state_dist_pmf(env, states)

validate(env, gflownet[, n_validation_samples, ...])

Evaluates the current gflownet on the given environment.

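gfn.utils.common.get_terminating_state_dist_pmf(env, states)
Parameters
  • env (gfn.env.Env) – The environment the terminating states belong to.

  • states (gfn.states.States) – The terminating states from which the empirical pmf is computed.

Return type

torchtyping.TensorType[n_states, float]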
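The return type suggests a dense vector holding one probability per terminating state. A minimal pure-Python sketch of such an empirical pmf follows. It assumes each terminating state has already been mapped to an integer index in range(n_states). The helper name `empirical_terminating_pmf` and that index mapping are hypothetical, and the real function returns a tensor rather than a list.

```python
from collections import Counter
from typing import List, Sequence


def empirical_terminating_pmf(state_indices: Sequence[int], n_states: int) -> List[float]:
    """Hypothetical sketch: empirical pmf over n_states terminating states.

    Assumes every terminating state is already encoded as an integer index
    in range(n_states); the real function works on gfn.states.States.
    """
    if len(state_indices) == 0:
        raise ValueError("need at least one state to estimate a pmf")
    counts = Counter(state_indices)
    for idx in counts:
        if not 0 <= idx < n_states:
            raise ValueError(f"state index {idx} outside range(0, {n_states})")
    total = len(state_indices)
    # Unvisited states get probability 0; visited ones get their frequency.
    return [counts[i] / total for i in range(n_states)]


# Four samples over four states: state 2 is visited twice, state 1 never.
print(empirical_terminating_pmf([0, 2, 2, 3], n_states=4))  # [0.25, 0.0, 0.5, 0.25]
```

Counting with a Counter and then filling a dense list keeps unvisited states explicit, which matters when the pmf is later compared entry by entry with a target distribution.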

gfn.utils.common.validate(env, gflownet, n_validation_samples=1000, visited_terminating_states=None)

Evaluates the current gflownet on the given environment.

This is intended for environments with a known target reward. Validation is done by computing the L1 distance between the empirical distribution of terminating states produced by the learned gflownet and the target distribution.
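The metric can be sketched as follows, assuming both distributions are given as equal-length lists of probabilities indexed by terminating state. The sketch uses the textbook L1 distance (the sum of absolute differences, ranging from 0 to 2). The library may normalize it differently, for example by averaging over states, so treat absolute values as implementation-specific. `l1_distance` is a hypothetical helper.

```python
from typing import Sequence


def l1_distance(p: Sequence[float], q: Sequence[float]) -> float:
    """Sum of absolute differences between two pmfs over the same states."""
    if len(p) != len(q):
        raise ValueError("pmfs must be defined over the same number of states")
    return sum(abs(pi - qi) for pi, qi in zip(p, q))


empirical = [0.25, 0.0, 0.5, 0.25]  # e.g. estimated from sampled terminating states
target = [0.25, 0.25, 0.25, 0.25]   # e.g. normalized target reward
print(l1_distance(empirical, target))  # 0.5
```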

Parameters
  • env (gfn.env.Env) – The environment to evaluate the gflownet on.

  • gflownet (gfn.gflownet.GFlowNet) – The gflownet to evaluate.

  • n_validation_samples (int) – The number of samples used to estimate the pmf.

  • visited_terminating_states (Optional[gfn.states.States]) – The terminating states visited during training. If given, the pmf is computed from the last n_validation_samples of these states. Otherwise, n_validation_samples new states are sampled for evaluation.
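The two branches of visited_terminating_states can be sketched as below. States are modeled as a plain Python sequence ordered from oldest to newest, and a hypothetical `sample_fn` stands in for the gflownet's sampler. Neither `select_validation_states` nor this sequence model reflects the actual gfn.states.States API.

```python
from typing import Callable, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def select_validation_states(
    visited: Optional[Sequence[T]],
    n_validation_samples: int,
    sample_fn: Callable[[int], List[T]],
) -> List[T]:
    """Hypothetical sketch of the two branches for visited_terminating_states."""
    if visited is not None:
        # Reuse the most recent training samples (all of them if fewer were visited).
        history = list(visited)
        start = max(0, len(history) - n_validation_samples)
        return history[start:]
    # Otherwise draw fresh samples from the sampler under evaluation.
    return sample_fn(n_validation_samples)


history = ["s1", "s2", "s3", "s4", "s5"]
print(select_validation_states(history, 2, sample_fn=lambda n: ["new"] * n))  # ['s4', 's5']
print(select_validation_states(None, 3, sample_fn=lambda n: ["new"] * n))     # ['new', 'new', 'new']
```

Computing the start index explicitly, instead of slicing with `[-n:]`, avoids returning the whole history when n_validation_samples is 0.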

Return type

Dict[str, float]

Returns: A dictionary containing the L1 validation metric. If the gflownet is a TBGFlowNet, i.e. it contains LogZ, the absolute difference between the learned and the target LogZ is also included in the dictionary.
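The shape of the return value can be sketched as follows. The key names "l1_dist" and "logZ_diff" and the helper `build_validation_info` are assumptions for illustration. Check the keys actually returned by your installed version before relying on them.

```python
from typing import Dict, Optional


def build_validation_info(
    l1_dist: float,
    learned_logZ: Optional[float] = None,
    true_logZ: Optional[float] = None,
) -> Dict[str, float]:
    """Hypothetical sketch of assembling the returned metrics dictionary."""
    info = {"l1_dist": l1_dist}
    # Only gflownets that learn a LogZ estimate (e.g. TB) get the extra entry.
    if learned_logZ is not None and true_logZ is not None:
        info["logZ_diff"] = abs(learned_logZ - true_logZ)
    return info


print(build_validation_info(0.5))            # {'l1_dist': 0.5}
print(build_validation_info(0.5, 2.0, 1.5))  # {'l1_dist': 0.5, 'logZ_diff': 0.5}
```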