Train
class Trainer
Class object to set up and carry out the training.
Takes as input a generator that produces SR images and, optionally, a discriminator network and a feature extractor, which are used to build the components of the perceptual loss. Compiles the model(s) and trains in a GAN fashion if a discriminator is provided; otherwise it carries out regular ISR training.
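The conditional setup described above can be illustrated with a minimal, framework-free sketch. It assumes that the total loss is a weighted sum of the per-component losses, keyed by the same names used in loss_weights ('generator', 'discriminator', 'feature_extractor'). The helpers training_mode and total_loss are hypothetical and are not part of the Trainer API; the numbers in the example are illustrative only.

```python
from typing import Dict, Optional


def training_mode(discriminator: Optional[object]) -> str:
    """GAN training when a discriminator is given, regular ISR training otherwise."""
    return "gan" if discriminator is not None else "regular"


def total_loss(component_losses: Dict[str, float], loss_weights: Dict[str, float]) -> float:
    """Weighted sum of the loss components that are present (hypothetical helper).

    'generator' is always required; 'discriminator' and 'feature_extractor'
    only appear when the corresponding networks are passed to the trainer.
    """
    if "generator" not in component_losses:
        raise ValueError("the 'generator' loss component is required")
    total = 0.0
    for name, value in component_losses.items():
        # Refuse to silently drop a component that has no weight.
        if name not in loss_weights:
            raise KeyError(f"no weight given for loss component {name!r}")
        total += loss_weights[name] * value
    return total


# Illustrative values only, not recommended settings.
weights = {"generator": 1.0, "discriminator": 0.01, "feature_extractor": 0.0833}
# Regular training: only the generator's own loss contributes.
print(training_mode(None), total_loss({"generator": 0.05}, weights))
# GAN training: all three components contribute.
print(training_mode(object()), total_loss(
    {"generator": 0.05, "discriminator": 0.7, "feature_extractor": 0.3}, weights))
```

In Keras, per-output weighting of this kind is usually expressed through the loss_weights argument of Model.compile; the sketch only makes the arithmetic explicit.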
Args
- generator: Keras model, the super-scaling (generator) network.
- discriminator: Keras model, the discriminator network for the adversarial component of the perceptual loss.
- feature_extractor: Keras model, the feature extractor network for the deep-features component of the perceptual loss.
- lr_train_dir: path to the directory containing the Low-Res images for training.
- hr_train_dir: path to the directory containing the High-Res images for training.
- lr_valid_dir: path to the directory containing the Low-Res images for validation.
- hr_valid_dir: path to the directory containing the High-Res images for validation.
- learning_rate: float, the initial learning rate.
- loss_weights: dictionary used to weigh the components of the loss function. Contains 'generator' for the generator loss component, and can contain 'discriminator' and 'feature_extractor' for the adversarial and deep-features components respectively.
- logs_dir: path to the directory where the TensorBoard logs are saved.
- weights_dir: path to the directory where the weights are saved.
- dataname: string identifying the dataset used for the training session.
- weights_generator: path to the pre-trained generator's weights, for transfer learning.
- weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.
- n_validation: integer, number of samples from the validation set used for validation during training.
- flatness: dictionary, determines the 'flatness' threshold level for the training patches. See the TrainerHelper class for more details.
- lr_decay_frequency: integer, number of epochs between successive learning rate reductions.
- lr_decay_factor: float with 0 < lr_decay_factor < 1, multiplicative factor applied to the learning rate at each reduction.
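The two decay arguments describe a step schedule. The sketch below assumes the learning rate is multiplied by lr_decay_factor once every lr_decay_frequency epochs, with epochs counted from 0; the function decayed_learning_rate is hypothetical, and the exact schedule used by Trainer may differ.

```python
def decayed_learning_rate(epoch: int, learning_rate: float,
                          lr_decay_factor: float, lr_decay_frequency: int) -> float:
    """Step decay: multiply by lr_decay_factor once every lr_decay_frequency epochs."""
    if not 0 < lr_decay_factor < 1:
        raise ValueError("lr_decay_factor must satisfy 0 < lr_decay_factor < 1")
    if lr_decay_frequency <= 0:
        raise ValueError("lr_decay_frequency must be a positive integer")
    # Number of completed decay periods, counting epochs from 0.
    n_decays = epoch // lr_decay_frequency
    return learning_rate * lr_decay_factor ** n_decays


# Halve the rate every 30 epochs, starting from 4e-4 (illustrative values).
for epoch in (0, 29, 30, 59, 60):
    print(epoch, decayed_learning_rate(epoch, 4e-4, 0.5, 30))
```

A function of this shape could be wrapped for tf.keras.callbacks.LearningRateScheduler, whose schedule receives the epoch index and the current learning rate; whether Trainer does this internally is not stated here.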
Methods
- train: combines the networks and triggers training with the specified settings.
__init__
def __init__(generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, loss_weights, log_dirs, fallback_save_every_n_epochs, dataname, weights_generator, weights_discriminator, n_validation, flatness, learning_rate, adam_optimizer, losses, metrics)
update_training_config
def update_training_config(settings)
Summarizes the training settings.
train
def train(epochs, steps_per_epoch, batch_size, monitored_metrics)
Carries out the training for the given number of epochs and sends the losses to TensorBoard.
Args
- epochs: number of epochs to train for.
- steps_per_epoch: number of batches per epoch.
- batch_size: number of images per batch.
- monitored_metrics: dictionary whose keys are the metrics monitored by the weight-saving logic and whose values are the modes that trigger weight saving ('min' or 'max').
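The 'min'/'max' modes in monitored_metrics can be read as "save when the metric reaches a new lowest/highest value". The sketch below assumes a save is triggered only on a strict improvement over every earlier epoch (ties do not count). The class BestMetricTracker and the metric names val_loss and val_PSNR are hypothetical and are not taken from the Trainer implementation.

```python
from typing import Dict, List


class BestMetricTracker:
    """Hypothetical helper: reports which monitored metrics reached a new best value."""

    def __init__(self, monitored_metrics: Dict[str, str]):
        for metric, mode in monitored_metrics.items():
            if mode not in ("min", "max"):
                raise ValueError(f"mode for {metric!r} must be 'min' or 'max', got {mode!r}")
        self.modes = dict(monitored_metrics)
        self.best: Dict[str, float] = {}

    def improved(self, epoch_logs: Dict[str, float]) -> List[str]:
        """Return the monitored metrics that strictly improved this epoch."""
        improved = []
        for metric, mode in self.modes.items():
            if metric not in epoch_logs:
                continue  # metric not logged this epoch
            value = epoch_logs[metric]
            best = self.best.get(metric)
            is_better = (
                best is None
                or (mode == "min" and value < best)
                or (mode == "max" and value > best)
            )
            if is_better:
                self.best[metric] = value
                improved.append(metric)
        return improved


tracker = BestMetricTracker({"val_loss": "min", "val_PSNR": "max"})
for logs in ({"val_loss": 0.10, "val_PSNR": 28.0}, {"val_loss": 0.12, "val_PSNR": 28.5}):
    for metric in tracker.improved(logs):
        print(f"would save weights: new best {metric} = {logs[metric]}")
```

Keras offers comparable behaviour through tf.keras.callbacks.ModelCheckpoint with save_best_only=True and mode='min' or 'max', though it monitors a single metric per callback.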