class AWS

Cloud provider class for running training on AWS.

Cloud instances are created and destroyed using Terraform. The instance is provisioned with nvidia-docker and training is run in a Docker container using the public Docker image idealo/tensorflow-image-atm:1.13.1.

All commands on the EC2 instance are run via SSH.

For training, the local image and job directories are synced to S3. After training, the trained models are synced back to S3 and then to the local job directory.

  • tf_dir: Directory with Terraform files for AWS setup.

  • region: AWS region [eu-west-1, eu-central-1].

  • instance_type: AWS GPU instance type [g2.*, p2.*, p3.*].

  • vpc_id: AWS Virtual Private Cloud ID.

  • s3_bucket: AWS S3 bucket where all training files are stored (it is not created by Terraform, so it must already exist).

  • job_dir: Job directory on the local system (needed for logging).

  • cloud_tag: Name under which all AWS resources are set up.
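The region and instance-type constraints listed above can be checked before any Terraform call is made. The sketch below treats the bracketed lists as the complete set of supported values, which is an assumption; `validate_cloud_params` is a hypothetical helper, not part of the `AWS` class.

```python
from fnmatch import fnmatchcase

# Values copied from the parameter list; treating them as the only
# supported values is an assumption of this sketch.
SUPPORTED_REGIONS = ("eu-west-1", "eu-central-1")
SUPPORTED_INSTANCE_PATTERNS = ("g2.*", "p2.*", "p3.*")


def validate_cloud_params(region: str, instance_type: str) -> None:
    """Hypothetical helper: raise ValueError for an unsupported region or instance type."""
    if region not in SUPPORTED_REGIONS:
        raise ValueError(f"unsupported region: {region!r}")
    # fnmatchcase matches shell-style wildcards case-sensitively, e.g. "p2.*" matches "p2.xlarge".
    if not any(fnmatchcase(instance_type, p) for p in SUPPORTED_INSTANCE_PATTERNS):
        raise ValueError(f"unsupported instance type: {instance_type!r}")
```

Failing early here avoids provisioning partial infrastructure with an instance type that is not available in the chosen region.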


def __init__(tf_dir, region, instance_type, vpc_id, s3_bucket, job_dir, cloud_tag, **kwargs)

Initializes the cloud component.

Sets the remote working directory and ensures that the S3 bucket prefix is correct.
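A minimal sketch of the bucket-prefix check, assuming that "correct" means the bucket is given as an `s3://` URI, the form the AWS CLI expects. `normalize_s3_bucket` is a hypothetical helper name, not the class's actual method.

```python
def normalize_s3_bucket(s3_bucket: str) -> str:
    """Hypothetical helper: make sure the bucket name carries the s3:// scheme."""
    prefix = "s3://"
    if not s3_bucket.startswith(prefix):
        s3_bucket = prefix + s3_bucket
    return s3_bucket
```

With this, both `my-bucket` and `s3://my-bucket` resolve to the same URI, so later sync commands can simply append sub-paths.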


def init()

Runs Terraform initialization.


def apply()

Runs Terraform apply.
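`init()` and `apply()` presumably wrap the Terraform CLI. The sketch below works under that assumption: it builds `terraform init` and `terraform apply` as argument lists and runs them from `tf_dir`. The variable names passed with `-var` (for example `region`, `instance_type`, `vpc_id`, `s3_bucket`, `cloud_tag`) assume the Terraform files declare variables with those names; all helpers are hypothetical.

```python
import subprocess
from typing import Dict, List


def terraform_init_cmd() -> List[str]:
    # Downloads providers and prepares the working directory.
    return ["terraform", "init"]


def terraform_apply_cmd(variables: Dict[str, str]) -> List[str]:
    # -auto-approve skips the interactive confirmation prompt.
    cmd = ["terraform", "apply", "-auto-approve"]
    for key, value in sorted(variables.items()):
        cmd += ["-var", f"{key}={value}"]
    return cmd


def run_in(tf_dir: str, cmd: List[str]) -> None:
    # Run from the Terraform directory so its .tf files are used; raise on failure.
    subprocess.run(cmd, cwd=tf_dir, check=True)
```

Passing arguments as a list with `cwd=tf_dir` avoids shell quoting problems and does not depend on whether the installed Terraform version supports the `-chdir` option.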


def train(image_dir, job_dir, **kwargs)

Runs training on the EC2 instance.

The following steps will be performed in sequence:

  • syncs the local image and job directories to S3

  • syncs S3 to the EC2 instance

  • launches the Docker training container on EC2

  • syncs the EC2 instance back to S3

  • syncs S3 back to the local job directory.
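The sync sequence above can be sketched as a list of shell commands: local steps call `aws s3 sync` directly, and EC2 steps wrap the same call in `ssh`. The S3 sub-paths (`image_dir`, `job_dir`), the `ec2-user` login, and all helper names are assumptions for illustration, not the class's actual layout.

```python
import shlex
import subprocess
from typing import List


def s3_sync(src: str, dst: str) -> List[str]:
    # `aws s3 sync SRC DST` copies new or changed files from SRC to DST.
    return ["aws", "s3", "sync", src, dst]


def over_ssh(host: str, remote_cmd: List[str], user: str = "ec2-user") -> List[str]:
    # Run a command on the EC2 instance via SSH; the login user is an assumption.
    return ["ssh", f"{user}@{host}", " ".join(shlex.quote(arg) for arg in remote_cmd)]


def training_plan(image_dir: str, job_dir: str, s3_bucket: str, host: str,
                  remote_workdir: str, docker_cmd: List[str]) -> List[List[str]]:
    # Hypothetical S3 and remote layout: one sub-path per directory.
    s3_images, s3_job = f"{s3_bucket}/image_dir", f"{s3_bucket}/job_dir"
    remote_images, remote_job = f"{remote_workdir}/image_dir", f"{remote_workdir}/job_dir"
    return [
        s3_sync(image_dir, s3_images),                      # local -> S3
        s3_sync(job_dir, s3_job),
        over_ssh(host, s3_sync(s3_images, remote_images)),  # S3 -> EC2
        over_ssh(host, s3_sync(s3_job, remote_job)),
        over_ssh(host, docker_cmd),                         # train in Docker on EC2
        over_ssh(host, s3_sync(remote_job, s3_job)),        # EC2 -> S3
        s3_sync(s3_job, job_dir),                           # S3 -> local
    ]


def run_plan(plan: List[List[str]]) -> None:
    # Execute the steps in order and stop at the first failing command.
    for cmd in plan:
        subprocess.run(cmd, check=True)
```

Building the plan separately from running it makes the order of the steps easy to inspect (or log) before anything touches S3 or the instance.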

Any of the pre-trained CNNs available in Keras can be used as the base model.

  • image_dir: Directory with images used for training.

  • job_dir: Directory with train_samples.json, val_samples.json, and class_mapping.json.

  • epochs_train_dense: Number of epochs to train dense layers.

  • epochs_train_all: Number of epochs to train all layers.

  • learning_rate_dense: Learning rate for the dense training phase.

  • learning_rate_all: Learning rate for the phase that trains all layers.

  • batch_size: Number of images per batch.

  • dropout_rate: Fraction of units randomly set to zero during training.

  • base_model_name: Name of the pre-trained CNN used as the base model.
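One way to pass the training parameters listed above into the container is to turn them into command-line flags for a `nvidia-docker run` call using the image named in the class description. The container command, the `--kebab-case` flag names, and the `/workdir` mount point are hypothetical; the sketch only illustrates how the keyword arguments could be forwarded.

```python
from typing import Any, List

DOCKER_IMAGE = "idealo/tensorflow-image-atm:1.13.1"  # image named in the class docs

# Training parameters documented for train(); the flag spelling below is hypothetical.
TRAIN_PARAMS = (
    "epochs_train_dense", "epochs_train_all", "learning_rate_dense",
    "learning_rate_all", "batch_size", "dropout_rate", "base_model_name",
)


def train_flags(**kwargs: Any) -> List[str]:
    # Emit "--flag value" pairs for known parameters only; other keys are ignored.
    flags: List[str] = []
    for name in TRAIN_PARAMS:
        if name in kwargs:
            flags += [f"--{name.replace('_', '-')}", str(kwargs[name])]
    return flags


def docker_train_cmd(remote_workdir: str, container_cmd: List[str], **kwargs: Any) -> List[str]:
    return [
        "nvidia-docker", "run", "--rm",
        "-v", f"{remote_workdir}:/workdir",  # mount point is an assumption
        DOCKER_IMAGE,
        *container_cmd,
        *train_flags(**kwargs),
    ]
```

Iterating over a fixed tuple keeps the flag order stable, which makes the generated command reproducible and easy to compare across runs.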


def destroy()

Runs Terraform destroy.
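Because GPU instances are billed while they run, it helps to destroy the infrastructure even when training fails. The sketch below assumes `destroy()` wraps `terraform destroy` with the same `-var` values used for apply; `terraform_destroy_cmd` and `run_with_cleanup` are hypothetical helpers.

```python
import subprocess
from typing import Callable, Dict, List


def terraform_destroy_cmd(variables: Dict[str, str]) -> List[str]:
    # -auto-approve skips the interactive confirmation prompt.
    cmd = ["terraform", "destroy", "-auto-approve"]
    for key, value in sorted(variables.items()):
        cmd += ["-var", f"{key}={value}"]
    return cmd


def destroy(tf_dir: str, variables: Dict[str, str]) -> None:
    subprocess.run(terraform_destroy_cmd(variables), cwd=tf_dir, check=True)


def run_with_cleanup(train_fn: Callable[[], None], destroy_fn: Callable[[], None]) -> None:
    # finally runs destroy_fn whether train_fn returns or raises; the error still propagates.
    try:
        train_fn()
    finally:
        destroy_fn()
```

For example, `run_with_cleanup(lambda: cloud.train(...), cloud.destroy)` would tear down the resources after a failed run as well as a successful one.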