Using custom models for CNN
To allow users to use custom models for encoding generation, we provide a CustomModel construct which serves as a wrapper for a user-defined feature extractor. The CustomModel consists of the following attributes:
name: The name of the custom model. Can be set to any string.model: A PyTorch model object, which is a subclass oftorch.nn.Moduleand implements theforwardmethod. The output of the forward method should be a tensor of shape (batch_size x features) . Alternatively, a__call__method is also accepted.transform: A function that transforms aPIL.Imageobject into a PyTorch tensor. Should correspond to the preprocessing logic of the supplied model.
CustomModel is provided while initializing the cnn object and can be used in the following 2 scenarios:
- Using the models provided with the
imagededuppackage. There are 3 models provided currently:MobileNetV3(MobileNetV3 Small)- This is the default.ViT(Vision Transformer- B16 IMAGENET1K_SWAG_E2E_V1)EfficientNet(EfficientNet B4- IMAGENET1K_V1)
from imagededup.methods import CNN
# Get CustomModel construct
from imagededup.utils import CustomModel
# Get the prepackaged models from imagededup
from imagededup.utils.models import ViT, MobilenetV3, EfficientNet
# Declare a custom config with CustomModel, the prepackaged models come with a name and transform function
custom_config = CustomModel(name=EfficientNet.name,
model=EfficientNet(),
transform=EfficientNet.transform)
# Use model_config argument to pass the custom config
cnn = CNN(model_config=custom_config)
# Use the model as usual
...
2.Using a user-defined custom model.
from imagededup.methods import CNN
# Get CustomModel construct
from imagededup.utils import CustomModel
# Import necessary pytorch constructs for initializing a custom feature extractor
import torch
from torchvision.transforms import transforms
# Declare custom feature extractor class
class MyModel(torch.nn.Module):
transform = transforms.Compose(
[
transforms.ToTensor()
]
)
name = 'my_custom_model'
def __init__(self):
super().__init__()
# Define the layers of the model here
def forward(self, x):
# Do something with x
return x
custom_config = CustomModel(name=MyModel.name,
model=MyModel(),
transform=MyModel.transform)
cnn = CNN(model_config=custom_config)
# Use the model as usual
...
It is not necessary to bundle name and transform functions with the model class. They can be passed separately as well.
Examples for both scenarios can be found in the examples section.