Train

class Trainer

Class object to set up and carry out the training.

Takes as input a generator that produces SR images. Optionally, it also takes a discriminator network and a feature extractor, which build the components of the perceptual loss. Compiles the model(s) and trains in a GAN fashion if a discriminator is provided; otherwise it carries out regular ISR training.

Args
  • generator: Keras model, the super-scaling, or generator, network.

  • discriminator: Keras model, the discriminator network for the adversarial component of the perceptual loss.

  • feature_extractor: Keras model, feature extractor network for the deep features component of perceptual loss function.

  • lr_train_dir: path to the directory containing the Low-Res images for training.

  • hr_train_dir: path to the directory containing the High-Res images for training.

  • lr_valid_dir: path to the directory containing the Low-Res images for validation.

  • hr_valid_dir: path to the directory containing the High-Res images for validation.

  • learning_rate: float, initial learning rate for the optimizer.

  • loss_weights: dictionary, used to weight the components of the loss function. Contains 'generator' for the generator loss component, and can contain 'discriminator' and 'feature_extractor' for the discriminator and deep features components respectively.

  • logs_dir: path to the directory where the tensorboard logs are saved.

  • weights_dir: path to the directory where the weights are saved.

  • dataname: string, used to identify what dataset is used for the training session.

  • weights_generator: path to the pre-trained generator's weights, for transfer learning.

  • weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.

  • n_validation: integer, number of validation samples used at training from the validation set.

  • flatness: dictionary, determines the 'flatness' threshold level for the training patches. See the TrainerHelper class for more details.

  • lr_decay_frequency: integer, number of epochs between learning rate reductions.

  • lr_decay_factor: float in (0, 1), multiplicative factor applied to the learning rate at each reduction.
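The step-decay schedule implied by `learning_rate`, `lr_decay_frequency`, and `lr_decay_factor` can be sketched as follows; the helper name and the exact epoch at which the decay is applied are illustrative assumptions, not part of the Trainer API:

```python
def stepped_lr(initial_lr, decay_factor, decay_frequency, epoch):
    """Learning rate after `epoch` epochs under a step decay:
    multiplied by `decay_factor` every `decay_frequency` epochs."""
    return initial_lr * decay_factor ** (epoch // decay_frequency)

# Example: learning_rate=4e-4, lr_decay_factor=0.5, lr_decay_frequency=30
print(stepped_lr(4e-4, 0.5, 30, 0))   # 0.0004 (no decay yet)
print(stepped_lr(4e-4, 0.5, 30, 30))  # halved once
print(stepped_lr(4e-4, 0.5, 30, 90))  # halved three times
```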

Methods
  • train: combines the networks and triggers training with the specified settings.

__init__

def __init__(generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, loss_weights, log_dirs, fallback_save_every_n_epochs, dataname, weights_generator, weights_discriminator, n_validation, flatness, learning_rate, adam_optimizer, losses, metrics)
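A construction sketch following the `__init__` signature above. It assumes the networks come from ISR's model classes (`RRDN`, `Discriminator`, `Cut_VGG19`) and that arguments not shown fall back to defaults; the paths, architecture parameters, and the exact shapes of the `log_dirs` and `flatness` dictionaries are illustrative assumptions to check against the implementation:

```python
# Hypothetical Trainer setup; requires the ISR package (skipped if absent).
try:
    from ISR.models import RRDN, Discriminator, Cut_VGG19
    from ISR.train import Trainer
    HAS_ISR = True
except ImportError:
    HAS_ISR = False  # ISR not installed; only the plain dictionaries below run

scale = 2
lr_patch = 40                   # low-res patch side, illustrative
loss_weights = {                # keys documented in the Args above
    'generator': 1.0,
    'discriminator': 0.01,
    'feature_extractor': 0.08,  # example weights, not defaults
}
log_dirs = {'logs': './logs', 'weights': './weights'}  # assumed dict shape
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}  # assumed keys

if HAS_ISR:
    generator = RRDN(
        arch_params={'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': scale},
        patch_size=lr_patch,
    )
    discriminator = Discriminator(patch_size=lr_patch * scale)
    feature_extractor = Cut_VGG19(patch_size=lr_patch * scale, layers_to_extract=[5, 9])

    trainer = Trainer(
        generator=generator,
        discriminator=discriminator,
        feature_extractor=feature_extractor,
        lr_train_dir='data/lr/train',  # illustrative paths
        hr_train_dir='data/hr/train',
        lr_valid_dir='data/lr/valid',
        hr_valid_dir='data/hr/valid',
        loss_weights=loss_weights,
        log_dirs=log_dirs,
        dataname='example_dataset',
        weights_generator=None,
        weights_discriminator=None,
        n_validation=40,
        flatness=flatness,
        learning_rate=1e-4,
    )
```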

update_training_config

def update_training_config(settings)

Summarizes the training settings.

train

def train(epochs, steps_per_epoch, batch_size, monitored_metrics)

Carries out the training for the given number of epochs. Sends the losses to TensorBoard.

Args
  • epochs: how many epochs to train for.

  • steps_per_epoch: how many batches per epoch.

  • batch_size: number of images per batch.

  • monitored_metrics: dictionary, the keys are the metrics that are monitored for the weights saving logic. The values are the modes that trigger the weights saving ('min' or 'max').
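The `monitored_metrics` contract (metric name mapped to a 'min' or 'max' mode) can be illustrated with a small helper; the metric names and the `improved` function below are hypothetical, not part of the Trainer API:

```python
monitored_metrics = {'val_loss': 'min', 'val_PSNR': 'max'}  # example metric names

def improved(metric, new_value, best_value, monitored_metrics):
    """Return True if `new_value` beats `best_value` under the metric's mode:
    'min' means lower is better, 'max' means higher is better."""
    mode = monitored_metrics[metric]
    return new_value < best_value if mode == 'min' else new_value > best_value

print(improved('val_loss', 0.10, 0.12, monitored_metrics))  # True: loss decreased
print(improved('val_PSNR', 29.5, 30.1, monitored_metrics))  # False: PSNR did not increase
```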