Train Helper
class TrainerHelper
Collection of useful functions to manage training sessions.
Args
- generator: Keras model, the super-scaling, or generator, network.
- logs_dir: path to the directory where the TensorBoard logs are saved.
- weights_dir: path to the directory where the weights are saved.
- lr_train_dir: path to the directory containing the Low-Res images.
- feature_extractor: Keras model, the feature extractor network for the deep features component of the perceptual loss function.
- discriminator: Keras model, the discriminator network for the adversarial component of the perceptual loss.
- dataname: string, identifies the dataset used for the training session.
- fallback_save_every_n_epochs: integer, the number of consecutive epochs without a weight save after which the weights are saved anyway, even if no metric improved.
- max_n_best_weights: maximum number of weight files to keep that are best on some metric.
- max_n_other_weights: maximum number of non-best weight files to keep.
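The two retention limits above can be illustrated with a minimal sketch. It assumes each list of saved weight files is ordered from oldest to newest and that the oldest files beyond each limit are the ones deleted; the function name files_to_remove is hypothetical and not part of this class.

```python
from typing import List


def files_to_remove(best_weights: List[str], other_weights: List[str],
                    max_n_best_weights: int, max_n_other_weights: int) -> List[str]:
    """Hypothetical helper: return the oldest weight files exceeding each limit.

    Both lists are assumed to be ordered from oldest to newest.
    """
    removals = []
    # Keep only the newest max_n_best_weights files saved for a metric improvement.
    excess_best = len(best_weights) - max_n_best_weights
    if excess_best > 0:
        removals.extend(best_weights[:excess_best])
    # Apply the same policy to weights saved without a metric improvement.
    excess_other = len(other_weights) - max_n_other_weights
    if excess_other > 0:
        removals.extend(other_weights[:excess_other])
    return removals
```

Keeping separate limits means a long run of fallback saves cannot push out the weights that scored best on a metric.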
Methods
- print_training_setting: see docstring.
- on_epoch_end: see docstring.
- epoch_n_from_weights_name: see docstring.
- initialize_training: see docstring.
__init__
def __init__(generator, weights_dir, logs_dir, lr_train_dir, feature_extractor, discriminator, dataname, weights_generator, weights_discriminator, fallback_save_every_n_epochs, max_n_other_weights, max_n_best_weights)
get_session_id
def get_session_id(basename)
Returns a unique session identifier.
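One way to build such an identifier is sketched below, assuming it combines the optional basename with a timestamp; the timestamp format is an assumption, and the now parameter is added only so the sketch is testable.

```python
from datetime import datetime
from typing import Optional


def get_session_id(basename: Optional[str] = None,
                   now: Optional[datetime] = None) -> str:
    """Sketch: build a session id from an optional basename and a timestamp."""
    now = now or datetime.now()
    # Assumed format: year-month-day_hourminute.
    timestamp = now.strftime("%Y-%m-%d_%H%M")
    if basename:
        return f"{basename}_{timestamp}"
    return timestamp
```

A timestamp only guarantees uniqueness at the resolution it encodes; two sessions started in the same minute with the same basename would collide under this format.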
update_config
def update_config(training_settings)
Adds the current settings dictionary to the existing settings (if any), under the session_id key.
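The merge described above can be sketched on plain dictionaries. This is a hypothetical standalone function; reading and writing the settings file that the method may use is left out.

```python
import copy
from typing import Any, Dict, Optional


def update_config(existing: Optional[Dict[str, Any]], session_id: str,
                  training_settings: Dict[str, Any]) -> Dict[str, Any]:
    """Sketch: return a new config with training_settings stored under session_id."""
    # Copy so the caller's existing settings are not mutated.
    config = copy.deepcopy(existing) if existing else {}
    config[session_id] = training_settings
    return config
```

Keying each entry by session_id lets the settings of every past session accumulate in a single config instead of being overwritten.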
print_training_setting
def print_training_setting(settings)
Prints the given training settings.
on_epoch_end
def on_epoch_end(epoch, losses, generator, discriminator, metrics)
Manages the operations performed at the end of each epoch: metric checks, weight saving and logging.
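The save decision that on_epoch_end has to make can be sketched as follows: save when any monitored metric improves, otherwise save once fallback_save_every_n_epochs consecutive epochs have passed without a save. The SaveDecider class and the mapping of each metric to "min" or "max" are hypothetical.

```python
from typing import Dict


class SaveDecider:
    """Hypothetical helper deciding whether to save weights at the end of an epoch."""

    def __init__(self, fallback_save_every_n_epochs: int, monitored: Dict[str, str]):
        # monitored maps a metric name to "min" or "max" (direction of improvement).
        self.fallback = fallback_save_every_n_epochs
        self.monitored = monitored
        self.best: Dict[str, float] = {}
        self.epochs_since_save = 0

    def should_save(self, metrics: Dict[str, float]) -> bool:
        improved = False
        for name, mode in self.monitored.items():
            value = metrics[name]
            best = self.best.get(name)
            if best is None or (value < best if mode == "min" else value > best):
                self.best[name] = value
                improved = True
        if improved:
            self.epochs_since_save = 0
            return True
        # No improvement: count the epoch and fall back to saving if overdue.
        self.epochs_since_save += 1
        if self.epochs_since_save >= self.fallback:
            self.epochs_since_save = 0
            return True
        return False
```

For example, with a fallback of 2 and validation losses 1.0, 1.5, 1.2, 0.9, the sketch saves at the first epoch (first value), skips the second, saves at the third (fallback) and saves at the fourth (improvement).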
epoch_n_from_weights_name
def epoch_n_from_weights_name(w_name)
Extracts the last epoch number from a standardized weights file name. This only works if the name contains 'epoch' followed by 3 digits, for example: some-architectureepoch023suffix.hdf5
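The naming convention above can be parsed with a regular expression, as in the sketch below. How the actual method handles names that do not match is not stated here; this sketch assumes it should raise an error.

```python
import re


def epoch_n_from_weights_name(w_name: str) -> int:
    """Sketch: return the epoch number encoded as 'epoch' + 3 digits in w_name."""
    match = re.search(r"epoch(\d{3})", w_name)
    if match is None:
        # Assumed behaviour for non-conforming names.
        raise ValueError(f"No 'epoch' followed by 3 digits in {w_name!r}")
    return int(match.group(1))
```

Applied to some-architectureepoch023suffix.hdf5, the captured group is "023", which int converts to 23.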
initialize_training
def initialize_training(object)
Function that is executed prior to training.
Wraps up most of the functions of this class: loads the weights if any are given, generates names for the session and weights, creates directories and prints the training settings.
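The directory-creation step could look like the sketch below. It assumes that each session gets its own subfolder, named after the session id, inside weights_dir and logs_dir; this layout and the make_session_dirs name are hypothetical.

```python
from pathlib import Path
from typing import Tuple


def make_session_dirs(weights_dir: str, logs_dir: str,
                      session_id: str) -> Tuple[Path, Path]:
    """Sketch: create per-session weights and log folders and return their paths."""
    session_weights = Path(weights_dir) / session_id
    session_logs = Path(logs_dir) / session_id
    for directory in (session_weights, session_logs):
        # parents=True creates missing intermediate folders; exist_ok avoids errors on reruns.
        directory.mkdir(parents=True, exist_ok=True)
    return session_weights, session_logs
```

Using exist_ok=True makes the call idempotent, so resuming a session with the same id does not fail.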