Train and Evaluate
Train¶
Functions:

- run – Trains the model. Can additionally evaluate on a testset, using best weights obtained during training.

run¶

run(cfg: DictConfig) -> Tuple[dict, dict]

Trains the model. Can additionally evaluate on a testset, using the best weights obtained during training. This method is wrapped in an optional @task_wrapper decorator that controls behavior on failure. Useful for multiruns, saving info about the crash, etc.
Parameters:

- cfg (DictConfig) – Configuration composed by Hydra.

Returns: Tuple[dict, dict] – Dict with metrics and dict with all instantiated objects.
Source code in rl4co/tasks/train.py
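Example: composing a config and calling `run` directly. This is a minimal sketch; the `../configs` path and the `experiment=routing/am` override are assumptions about your checkout, and the usual entry point is the Hydra CLI in `rl4co/tasks/train.py`.

```python
from hydra import compose, initialize

from rl4co.tasks.train import run

# Compose a config with Hydra's Compose API and hand it to `run`.
# config_path and the override below are assumptions about the repo layout.
with initialize(version_base="1.3", config_path="../configs"):
    cfg = compose(config_name="main", overrides=["experiment=routing/am"])

metric_dict, object_dict = run(cfg)  # Tuple[dict, dict]
print(metric_dict)
```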
Evaluate¶
Classes:

- EvalBase – Base class for evaluation
- GreedyEval – Evaluates the policy using greedy decoding and a single trajectory
- AugmentationEval – Evaluates the policy via N state augmentations
- SamplingEval – Evaluates the policy via N samples from the policy
- GreedyMultiStartEval – Evaluates the policy via num_starts greedy multistart samples from the policy
- GreedyMultiStartAugmentEval – Evaluates the policy via num_starts samples from the policy and num_augment augmentations of each sample

Functions:

- get_automatic_batch_size – Automatically reduces the batch size based on the eval function
EvalBase¶
EvalBase(env, progress=True, **kwargs)
Base class for evaluation
Parameters:

- env – Environment
- progress – Whether to show a progress bar
- **kwargs – Additional arguments (to be implemented in subclasses)
Source code in rl4co/tasks/eval.py
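Subclasses plug a decoding strategy into this shared loop. The sketch below shows the pattern; the `_inner(policy, td)` hook and its `(actions, rewards)` return convention are assumptions inferred from the built-in evaluators, so verify against `rl4co/tasks/eval.py` before relying on them.

```python
from rl4co.tasks.eval import EvalBase


class TemperatureGreedyEval(EvalBase):
    """Hypothetical evaluator: greedy decoding at a custom temperature."""

    name = "temperature_greedy"  # assumption: evaluators carry a `name` tag

    def __init__(self, env, temperature=0.5, **kwargs):
        super().__init__(env, kwargs.get("progress", True))
        self.temperature = temperature

    def _inner(self, policy, td):
        # Assumed subclass hook: decode one batch, return (actions, rewards).
        out = policy(
            td.clone(),
            decode_type="greedy",
            temperature=self.temperature,
            return_actions=True,
        )
        return out["actions"], out["reward"]
```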
GreedyEval¶
GreedyEval(env, **kwargs)
Bases: EvalBase
Evaluates the policy using greedy decoding and a single trajectory
Source code in rl4co/tasks/eval.py
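Example usage. Every evaluator is a callable over a policy and a dataloader; the `generator_params` argument, the `env.dataset` helper, and the `avg_reward` result key below are assumptions about the current rl4co API.

```python
from torch.utils.data import DataLoader

from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel
from rl4co.tasks.eval import GreedyEval

env = TSPEnv(generator_params=dict(num_loc=50))
policy = AttentionModel(env).policy  # untrained here; load a checkpoint in practice

dataset = env.dataset([1_000])  # assumption: env-provided test dataset
dataloader = DataLoader(dataset, batch_size=128, collate_fn=dataset.collate_fn)

results = GreedyEval(env)(policy, dataloader)
print(results["avg_reward"])  # assumption: aggregate reward key
```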
AugmentationEval¶
AugmentationEval(
env,
num_augment=8,
force_dihedral_8=False,
feats=None,
**kwargs
)
Bases: EvalBase
Evaluates the policy via N state augmentations. force_dihedral_8 forces the use of 8 augmentations (rotations and flips) as in POMO: https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8
Parameters:

- num_augment (int, default: 8) – Number of state augmentations
- force_dihedral_8 (bool, default: False) – Whether to force the use of 8 augmentations
Source code in rl4co/tasks/eval.py
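Construction sketch. That the best reward across augmented copies of each instance is reported is an assumption based on the POMO scheme.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import AugmentationEval

env = TSPEnv(generator_params=dict(num_loc=50))

# force_dihedral_8=True pins the augmentations to the 8 symmetries of the
# square (4 rotations x 2 reflections), the POMO setting; otherwise
# num_augment augmentations of the instance features are applied.
evaluator = AugmentationEval(env, num_augment=8, force_dihedral_8=True)
```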
SamplingEval¶
SamplingEval(
env,
samples,
softmax_temp=None,
select_best=True,
temperature=1.0,
top_p=0.0,
top_k=0,
**kwargs
)
Bases: EvalBase
Evaluates the policy via N samples from the policy
Parameters:

- samples (int) – Number of samples to take
- softmax_temp (float, default: None) – Temperature for softmax sampling. The higher the temperature, the more random the sampling
Source code in rl4co/tasks/eval.py
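Construction sketch; how `top_p` and `top_k` interact with the temperatures is an assumption based on standard nucleus/top-k sampling.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import SamplingEval

env = TSPEnv(generator_params=dict(num_loc=50))

# 100 stochastic rollouts per instance; select_best (default True) keeps
# only the best one. A higher softmax_temp flattens the action distribution
# (more exploration), a lower one sharpens it; top_p=0.95 restricts sampling
# to the smallest nucleus of actions whose probabilities sum to 0.95.
evaluator = SamplingEval(env, samples=100, softmax_temp=1.0, top_p=0.95)
```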
GreedyMultiStartEval¶
GreedyMultiStartEval(env, num_starts=None, **kwargs)
Bases: EvalBase
Evaluates the policy via num_starts greedy multistart samples from the policy
Parameters:

- num_starts (int, default: None) – Number of greedy multistarts to use
Source code in rl4co/tasks/eval.py
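Construction sketch; that `num_starts=None` falls back to an environment-dependent default (e.g. one start per node for TSP-like problems) is an assumption.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import GreedyMultiStartEval

env = TSPEnv(generator_params=dict(num_loc=50))

# POMO-style evaluation: one deterministic greedy rollout per start node,
# keeping the best of the num_starts trajectories for each instance.
evaluator = GreedyMultiStartEval(env, num_starts=50)
```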
GreedyMultiStartAugmentEval¶
GreedyMultiStartAugmentEval(
env,
num_starts=None,
num_augment=8,
force_dihedral_8=False,
feats=None,
**kwargs
)
Bases: EvalBase
Evaluates the policy via num_starts samples from the policy and num_augment augmentations of each sample. force_dihedral_8 forces the use of 8 augmentations (rotations and flips) as in POMO: https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8
Parameters:

- num_starts – Number of greedy multistart samples
- num_augment – Number of augmentations per sample
- force_dihedral_8 – If True, force the use of 8 augmentations (rotations and flips) as in POMO
Source code in rl4co/tasks/eval.py
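Construction sketch combining both axes, for a total of num_starts x num_augment greedy rollouts per instance; that the best reward across all of them is reported is an assumption consistent with the single-axis evaluators.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import GreedyMultiStartAugmentEval

env = TSPEnv(generator_params=dict(num_loc=50))

# 50 starts x 8 dihedral augmentations = 400 greedy rollouts per instance.
evaluator = GreedyMultiStartAugmentEval(
    env, num_starts=50, num_augment=8, force_dihedral_8=True
)
```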
get_automatic_batch_size¶
get_automatic_batch_size(
eval_fn, start_batch_size=8192, max_batch_size=4096
)
Automatically reduces the batch size based on the eval function
Parameters:

- eval_fn – The eval function
- start_batch_size – The starting batch size. This should be the theoretical maximum batch size
- max_batch_size – The maximum batch size. This is the practical maximum batch size
Source code in rl4co/tasks/eval.py
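Usage sketch. The idea is that memory-hungry eval modes (sampling, multistart) effectively multiply the batch, so the theoretical maximum is scaled down per method and capped at the practical maximum; the exact scaling factors are implementation details of `eval.py`.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import SamplingEval, get_automatic_batch_size

env = TSPEnv(generator_params=dict(num_loc=50))
evaluator = SamplingEval(env, samples=100)

# Sampling 100 rollouts per instance inflates memory use, so the returned
# batch size will be well below the theoretical maximum.
batch_size = get_automatic_batch_size(
    evaluator, start_batch_size=8192, max_batch_size=4096
)
print(f"Using batch size {batch_size}")
```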