turn_around training history
This commit is contained in:
@@ -6,7 +6,7 @@ from scripts.commons.UI import UI
|
||||
from shutil import copy
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, CallbackList, BaseCallback
|
||||
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, CallbackList, BaseCallback, StopTrainingOnNoModelImprovement
|
||||
from typing import Callable
|
||||
# from world.world import World
|
||||
from xml.dom import minidom
|
||||
@@ -266,11 +266,28 @@ class Train_Base():
|
||||
|
||||
evaluate = bool(eval_env is not None and eval_freq is not None)
|
||||
|
||||
# Optional early stop: stop training when eval reward does not improve for N eval rounds.
|
||||
no_improve_evals = int(os.environ.get("GYM_CPU_EARLY_STOP_NO_IMPROVE_EVALS", "0"))
|
||||
min_evals_before_stop = int(os.environ.get("GYM_CPU_EARLY_STOP_MIN_EVALS", "6"))
|
||||
stop_on_no_improve = None
|
||||
if evaluate and no_improve_evals > 0:
|
||||
stop_on_no_improve = StopTrainingOnNoModelImprovement(
|
||||
max_no_improvement_evals=no_improve_evals,
|
||||
min_evals=min_evals_before_stop,
|
||||
verbose=1,
|
||||
)
|
||||
|
||||
# Create evaluation callback
|
||||
eval_callback = None if not evaluate else EvalCallback(eval_env, n_eval_episodes=eval_eps, eval_freq=eval_freq,
|
||||
log_path=path,
|
||||
best_model_save_path=path, deterministic=True,
|
||||
render=False)
|
||||
eval_callback = None if not evaluate else EvalCallback(
|
||||
eval_env,
|
||||
n_eval_episodes=eval_eps,
|
||||
eval_freq=eval_freq,
|
||||
log_path=path,
|
||||
best_model_save_path=path,
|
||||
deterministic=True,
|
||||
render=False,
|
||||
callback_after_eval=stop_on_no_improve,
|
||||
)
|
||||
|
||||
# Create custom callback to display evaluations
|
||||
custom_callback = None if not evaluate else Cyclic_Callback(eval_freq,
|
||||
|
||||
Reference in New Issue
Block a user