
Train Helper

class TrainerHelper

Collection of useful functions to manage training sessions.

Args
  • generator: Keras model, the super-scaling (generator) network.

  • logs_dir: path to the directory where the TensorBoard logs are saved.

  • weights_dir: path to the directory where the weights are saved.

  • lr_train_dir: path to the directory containing the low-resolution training images.

  • feature_extractor: Keras model, feature extractor network for the deep features component of perceptual loss function.

  • discriminator: Keras model, the discriminator network for the adversarial component of the perceptual loss.

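  • dataname: string, used to identify which dataset is used for the training session.

  • weights_generator: path to pre-trained generator weights, loaded before training when given.

  • weights_discriminator: path to pre-trained discriminator weights, loaded before training when given.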

  • fallback_save_every_n_epochs: integer, the number of consecutive epochs without a weights save after which the weights are saved anyway, even if no metric improved.

  • max_n_best_weights: maximum number of weight files to keep that are the best on some metric.

  • max_n_other_weights: maximum number of non-best weight files to keep.
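The two retention limits can be read as a simple pruning rule. The sketch below is an assumption about how such a rule could work, not the class's actual code: the hypothetical helper `split_kept_and_removed` takes lists of weight file names ordered oldest first and keeps only the most recent ones of each kind.

```python
from typing import List, Tuple


def split_kept_and_removed(
    best_weights: List[str],
    other_weights: List[str],
    max_n_best_weights: int,
    max_n_other_weights: int,
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: decide which weight files to keep or delete.

    Both input lists are assumed to be ordered oldest first, so the most
    recent files sit at the end and are the ones retained.
    """
    # Guard against 0: lst[-0:] would return the whole list, not an empty one.
    kept_best = best_weights[-max_n_best_weights:] if max_n_best_weights > 0 else []
    kept_other = other_weights[-max_n_other_weights:] if max_n_other_weights > 0 else []
    removed = [w for w in best_weights if w not in kept_best]
    removed += [w for w in other_weights if w not in kept_other]
    return kept_best + kept_other, removed
```

Keeping the two pools separate means a long run of fallback saves can never push the best-on-some-metric weights out of the directory.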

Methods
  • print_training_setting: see docstring.

  • on_epoch_end: see docstring.

  • epoch_n_from_weights_name: see docstring.

  • initialize_training: see docstring.

__init__

def __init__(generator, weights_dir, logs_dir, lr_train_dir, feature_extractor, discriminator, dataname, weights_generator, weights_discriminator, fallback_save_every_n_epochs, max_n_other_weights, max_n_best_weights)

get_session_id

def get_session_id(basename)

Returns a unique session identifier.
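One common way to make a session identifier unique is to append a timestamp to a base name. This is a minimal sketch under that assumption; the timestamp format and the injectable `now` argument are illustrative choices, not documented behavior of `get_session_id`.

```python
from datetime import datetime
from typing import Optional


def get_session_id(basename: Optional[str] = None, now: Optional[datetime] = None) -> str:
    """Hypothetical sketch: build a session id from a timestamp.

    `now` can be passed explicitly so the result is reproducible in tests.
    """
    now = now or datetime.now()
    # Second resolution: two sessions started in the same second would collide.
    time_str = now.strftime("%Y-%m-%d_%H%M%S")
    return f"{basename}_{time_str}" if basename else time_str
```

For example, `get_session_id("rrdn", datetime(2020, 1, 2, 3, 4, 5))` gives `"rrdn_2020-01-02_030405"`.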

update_config

def update_config(training_settings)

Adds the current settings dictionary, under the session_id key, to the existing settings (if any).
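The merge described above can be sketched with plain dictionaries. This is an illustration under stated assumptions: the signature here takes the existing config and the session id explicitly, and persisting the result to disk (for example as a YAML or JSON file) is left out.

```python
import copy
from typing import Any, Dict, Optional


def update_config(
    existing: Optional[Dict[str, Any]],
    session_id: str,
    training_settings: Dict[str, Any],
) -> Dict[str, Any]:
    """Hypothetical sketch: nest the current settings under session_id."""
    # Copy so the caller's dictionaries are never mutated.
    config = copy.deepcopy(existing) if existing else {}
    # Each session gets its own top-level entry, so earlier sessions are preserved.
    config[session_id] = copy.deepcopy(training_settings)
    return config
```

Keying by session id lets one configuration file accumulate the settings of every training run.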

print_training_setting

def print_training_setting(settings)

Prints the given training settings.

on_epoch_end

def on_epoch_end(epoch, losses, generator, discriminator, metrics)

Manages the operations performed at the end of each epoch: metric checks, weight saving, and logging.
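The weight-saving part of these end-of-epoch checks can be sketched as a small state machine: save as "best" when any monitored metric improves, otherwise save as "other" once `fallback_save_every_n_epochs` epochs pass without a save. The `SaveTracker` class, its `save_decision` method, and the "min"/"max" convention are hypothetical; this is not the library's implementation.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class SaveTracker:
    """Hypothetical per-session state for deciding whether to save weights.

    `monitored` maps a metric name to "min" (lower is better) or
    "max" (higher is better).
    """

    monitored: Dict[str, str]
    fallback_save_every_n_epochs: int
    best: Dict[str, float] = field(default_factory=dict)
    epochs_since_save: int = 0

    def save_decision(self, metrics: Dict[str, float]) -> Optional[str]:
        """Return "best", "other", or None (no save) for this epoch."""
        improved = False
        for name, mode in self.monitored.items():
            value = metrics[name]
            prev = self.best.get(name)
            better = prev is None or (value < prev if mode == "min" else value > prev)
            if better:
                self.best[name] = value
                improved = True
        if improved:
            # At least one monitored metric reached a new best value.
            self.epochs_since_save = 0
            return "best"
        self.epochs_since_save += 1
        if self.epochs_since_save >= self.fallback_save_every_n_epochs:
            # Fallback save: n epochs in a row passed without any save.
            self.epochs_since_save = 0
            return "other"
        return None
```

With `fallback_save_every_n_epochs=2` and validation losses 1.0, 1.5, 1.2, 0.9, the decisions are "best", no save, "other", "best".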

epoch_n_from_weights_name

def epoch_n_from_weights_name(w_name)

Extracts the last epoch number from the standardized weights name. This only works if the weights name contains 'epoch' followed by 3 digits, for example: some-architectureepoch023suffix.hdf5
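The naming rule above maps directly onto a regular expression. This is a sketch of the idea, not the method's actual code; raising `ValueError` when no match is found is an assumption, since the failure behavior is not documented here.

```python
import re


def epoch_n_from_weights_name(w_name: str) -> int:
    """Return the epoch number encoded as 'epoch' + 3 digits in w_name."""
    matches = re.findall(r"epoch(\d{3})", w_name)
    if not matches:
        raise ValueError(f"no 'epoch' followed by 3 digits in {w_name!r}")
    # If the pattern occurs more than once, take the last occurrence.
    return int(matches[-1])
```

For the example name, `epoch_n_from_weights_name("some-architectureepoch023suffix.hdf5")` returns `23`; the leading zero is dropped by `int`.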

initialize_training

def initialize_training(object)

Function that is executed prior to training.

Wraps up most of the functions of this class: it loads the weights if any are given, generates the names for the session and the weights, creates the directories, and prints the training session settings.
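Two of the preparation steps listed above, creating per-session directories and generating weight file names, can be sketched as follows. The directory layout (one subfolder per session id) and both helper names are hypothetical; the name format is chosen only so that it satisfies the 'epoch' + 3 digits convention that epoch_n_from_weights_name expects.

```python
from pathlib import Path
from typing import Tuple


def make_session_paths(weights_dir: str, logs_dir: str, session_id: str) -> Tuple[Path, Path]:
    """Hypothetical: create per-session weights and logs directories."""
    weights_path = Path(weights_dir) / session_id
    logs_path = Path(logs_dir) / session_id
    # exist_ok=True makes re-running the setup for the same session harmless.
    weights_path.mkdir(parents=True, exist_ok=True)
    logs_path.mkdir(parents=True, exist_ok=True)
    return weights_path, logs_path


def weights_name(model_name: str, epoch: int, suffix: str = ".hdf5") -> str:
    """Hypothetical naming scheme: zero-pad the epoch to 3 digits."""
    return f"{model_name}epoch{epoch:03d}{suffix}"
```

For instance, `weights_name("rrdn-", 23)` produces `"rrdn-epoch023.hdf5"`. Note that the 3-digit padding only round-trips for epochs below 1000.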