gfn.utils.common
Module Contents
Functions
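- get_terminating_state_dist_pmf(env, states)
- validate(env, gflownet, n_validation_samples=1000, visited_terminating_states=None): Evaluates the current gflownet on the given environment.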
- gfn.utils.common.get_terminating_state_dist_pmf(env, states)
- Parameters
env (gfn.env.Env) –
states (gfn.states.States) –
- Return type
torchtyping.TensorType[n_states, float]
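The return type suggests a probability mass function with one entry per state. The following is a minimal, library-independent sketch of how such an empirical pmf could be computed. It assumes each terminating state can be mapped to an integer index in `range(n_states)`. The helper name `empirical_pmf` and its arguments are hypothetical and are not part of `gfn`.

```python
from collections import Counter
from typing import List, Sequence


def empirical_pmf(state_indices: Sequence[int], n_states: int) -> List[float]:
    """Return the fraction of samples that landed on each state index.

    Hypothetical helper: assumes states are already encoded as integers
    in range(n_states); the real function takes env and States objects.
    """
    if not state_indices:
        raise ValueError("need at least one sampled state")
    counts = Counter(state_indices)
    total = len(state_indices)
    # States that were never visited get probability 0.
    return [counts.get(i, 0) / total for i in range(n_states)]


# Four sampled terminating states over a space of four states.
pmf = empirical_pmf([0, 2, 2, 3], n_states=4)
```

The resulting list sums to 1 and has length `n_states`, matching the shape described by the return type above.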
- gfn.utils.common.validate(env, gflownet, n_validation_samples=1000, visited_terminating_states=None)
Evaluates the current gflownet on the given environment.
This is intended for environments with a known target reward. Validation computes the L1 distance between the empirical distribution of terminating states sampled from the learned gflownet and the target distribution.
- Parameters
env (gfn.env.Env) – The environment to evaluate the gflownet on.
gflownet (gfn.gflownet.GFlowNet) – The gflownet to evaluate.
n_validation_samples (int) – The number of samples to use to evaluate the pmf.
visited_terminating_states (Optional[gfn.states.States]) – The terminating states visited during training. If given, the pmf is computed from the last n_validation_samples of these states. Otherwise, n_validation_samples terminating states are sampled from the gflownet for evaluation.
- Return type
Dict[str, float]
- Returns
A dictionary containing the L1 validation metric. If the gflownet is a TBGFlowNet, i.e. contains LogZ, then the (absolute) difference between the learned and the target LogZ is also returned in the dictionary.
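As an illustration of the metric described above, here is a minimal sketch in plain Python. It makes several assumptions: both distributions are given as equal-length lists of probabilities, the L1 distance is averaged over states (the library may normalize differently), and the dictionary keys `"l1_dist"` and `"logZ_diff"` are placeholders chosen for this example. The function name `l1_validation_metric` is hypothetical.

```python
import math
from typing import Dict, Optional, Sequence


def l1_validation_metric(
    empirical_pmf: Sequence[float],
    target_pmf: Sequence[float],
    learned_log_z: Optional[float] = None,
    target_log_z: Optional[float] = None,
) -> Dict[str, float]:
    """Compare an empirical pmf with a target pmf (illustrative only)."""
    if len(empirical_pmf) != len(target_pmf):
        raise ValueError("pmfs must have the same number of states")
    # Assumption: mean absolute difference over states.
    l1 = sum(abs(p - q) for p, q in zip(empirical_pmf, target_pmf)) / len(target_pmf)
    info = {"l1_dist": l1}
    # Only report the LogZ gap when both values are available,
    # mirroring the TBGFlowNet case described above.
    if learned_log_z is not None and target_log_z is not None:
        info["logZ_diff"] = abs(learned_log_z - target_log_z)
    return info


target = [0.1, 0.2, 0.3, 0.4]
empirical = [0.25, 0.0, 0.5, 0.25]
metrics = l1_validation_metric(
    empirical, target, learned_log_z=2.0, target_log_z=math.log(10.0)
)
```

When no LogZ values are passed, the returned dictionary holds only the L1 entry, which corresponds to the case where the gflownet does not learn LogZ.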