-
Notifications
You must be signed in to change notification settings - Fork 457
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ability to use custom CNN models (#190)
* Bump version to 0.3.1. * Add constructor to do custom model initialization. Add arguments to default constructor too. * Use model config to support alternate models. * Rename model config to custom_model and instead add model_config as a constructor param. * Add models.py * Refactor custom model to be a inherited type of namedtuple, allows for better type hinting. * Add efficientNet b4, remove Vit_h_14 as it was too large. Also add a name attribute to models. * Add test cases and docstrings. * Add user guide for custom model usage, add test case for custom forward call acceptance. * Add example notebook for using custom models. * Fix some docstrings. * Update mkdocs builspec to include custom model documentation. * Update custom model documentation. * Update cnn constructor documentation. * Update Readme to include info about custom cnn models.
- Loading branch information
Showing
10 changed files
with
747 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,324 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from pathlib import Path" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"image_dir = Path('../tests/data/mixed_images')" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Use one of the prepackaged models" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"from imagededup.methods import CNN\n", | ||
"\n", | ||
"# Get CustomModel construct\n", | ||
"from imagededup.utils import CustomModel\n", | ||
"\n", | ||
"# Get efficientnet model\n", | ||
"from imagededup.utils.models import EfficientNet\n", | ||
"# Other models include ViT, MobilenetV3 (default selection)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"# Declare a custom config with CustomModel, the prepackaged models come with a name and transform function\n", | ||
"custom_config = CustomModel(name=EfficientNet.name,\n", | ||
" model=EfficientNet(),\n", | ||
" transform=EfficientNet.transform)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"cnn_encoder = CNN(model_config=custom_config)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"duplicates_cnn = cnn_encoder.find_duplicates(image_dir=image_dir, scores=True)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"duplicates_cnn" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# User-defined model" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"from imagededup.methods import CNN\n", | ||
"\n", | ||
"# Get CustomModel construct\n", | ||
"from imagededup.utils import CustomModel\n", | ||
"\n", | ||
"# Import necessary pytorch constructs for initializing a custom feature extractor\n", | ||
"import torch\n", | ||
"from torchvision.transforms import transforms\n", | ||
"\n", | ||
"# Declare custom feature extractor class\n", | ||
"class MyModel(torch.nn.Module):\n", | ||
" transform = transforms.Compose(\n", | ||
" [\n", | ||
" transforms.Resize((256, 256)),\n", | ||
" transforms.ToTensor()\n", | ||
" ]\n", | ||
" ) # transform must take PIL.Image as input and return a torch.Tensor\n", | ||
"\n", | ||
" name = 'my_custom_model' # name can be any user-defined string\n", | ||
"\n", | ||
" def __init__(self):\n", | ||
" super().__init__()\n", | ||
" # Define the layers of the model here\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" # Add more operations here\n", | ||
" x = x.view(-1, 256*256*3) # output shape: batch_size x features\n", | ||
" return x" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"# Initialize the CNN using model_config parameter and setting it to the custom model\n", | ||
"custom_config = CustomModel(name=MyModel.name,\n", | ||
" model=MyModel(),\n", | ||
" transform=MyModel.transform)\n", | ||
"\n", | ||
"cnn = CNN(model_config=custom_config)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"duplicates_cnn = cnn.find_duplicates(image_dir=image_dir, scores=True)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"duplicates_cnn" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Using a huggingface model" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install transformers" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"from imagededup.methods import CNN\n", | ||
"\n", | ||
"# Get CustomModel construct\n", | ||
"from imagededup.utils import CustomModel\n", | ||
"\n", | ||
"# Import necessary constructs for initializing a huggingface transformers model\n", | ||
"from transformers import ViTModel, AutoImageProcessor\n", | ||
"import torch\n", | ||
"from torchvision.transforms import transforms\n", | ||
"\n", | ||
"VIT_MODEL = \"google/vit-base-patch16-224-in21k\"\n", | ||
"\n", | ||
"def vit_transform(image):\n", | ||
" transform = AutoImageProcessor.from_pretrained(VIT_MODEL)\n", | ||
" x = transform(image, return_tensors = 'pt')['pixel_values']\n", | ||
" return x\n", | ||
"\n", | ||
"class VitHgface(torch.nn.Module):\n", | ||
" transform = transforms.Lambda(vit_transform)\n", | ||
"\n", | ||
" name = 'ViT_hgface'\n", | ||
"\n", | ||
" def __init__(self):\n", | ||
" super().__init__()\n", | ||
" self.vit = ViTModel.from_pretrained(VIT_MODEL)\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" x = x.view(-1, 3, 224, 224)\n", | ||
" with torch.no_grad():\n", | ||
" out = self.vit(pixel_values=x)\n", | ||
" return out.pooler_output" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"custom_config = CustomModel(name=VitHgface.name,\n", | ||
" model=VitHgface(),\n", | ||
" transform=VitHgface.transform)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"cnn = CNN(model_config=custom_config)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"duplicates_cnn = cnn.find_duplicates(image_dir=image_dir, scores=True)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"duplicates_cnn" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"enc = cnn.encode_images(image_dir=image_dir)" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
Oops, something went wrong.