Add ability to use custom CNN models (#190)
* Bump version to 0.3.1.

* Add constructor to do custom model initialization. Add arguments to default constructor too.

* Use model config to support alternate models.

* Rename model config to custom_model and instead add model_config as a constructor param.

* Add models.py

* Refactor custom model to be an inherited type of namedtuple, which allows for better type hinting.

* Add EfficientNet B4, remove Vit_h_14 as it was too large. Also add a name attribute to models.

* Add test cases and docstrings.

* Add user guide for custom model usage, add test case for custom forward call acceptance.

* Add example notebook for using custom models.

* Fix some docstrings.

* Update mkdocs buildspec to include custom model documentation.

* Update custom model documentation.

* Update CNN constructor documentation.

* Update README to include info about custom CNN models.
tanujjain authored Apr 28, 2023
1 parent f64e4cc commit 837d86a
Showing 10 changed files with 747 additions and 91 deletions.
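In short, the new `model_config` constructor parameter takes a `CustomModel` namedtuple bundling a display name, a `torch.nn.Module` feature extractor, and a preprocessing transform. The sketch below condenses the prepackaged-model usage from the notebook added in this commit (`examples/use_custom_model.ipynb`); the image directory path is a placeholder.

```python
from imagededup.methods import CNN

# CustomModel construct and one of the prepackaged backbones
# (others include ViT and MobilenetV3, the default)
from imagededup.utils import CustomModel
from imagededup.utils.models import EfficientNet

# Prepackaged models ship with a name and a PIL.Image -> torch.Tensor transform
custom_config = CustomModel(name=EfficientNet.name,
                            model=EfficientNet(),
                            transform=EfficientNet.transform)

cnn_encoder = CNN(model_config=custom_config)
duplicates = cnn_encoder.find_duplicates(image_dir='path/to/image/directory', scores=True)
```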
5 changes: 3 additions & 2 deletions README.md
@@ -19,7 +19,7 @@ framework is also provided to judge the quality of deduplication for a given dat
 Following details the functionality provided by the package:
 
 - Finding duplicates in a directory using one of the following algorithms:
-- [Convolutional Neural Network](https://arxiv.org/abs/1704.04861) (CNN)
+- [Convolutional Neural Network](https://arxiv.org/abs/1905.02244#:~:text=MobileNetV3%20is%20tuned%20to%20mobile,improved%20through%20novel%20architecture%20advances.) (CNN) - Select from several prepackaged models or provide your own custom model.
 - [Perceptual hashing](http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html) (PHash)
 - [Difference hashing](http://www.hackerfactor.com/blog/index.php?/archives/529-Kind-of-Like-That.html) (DHash)
 - [Wavelet hashing](https://fullstackml.com/wavelet-image-hash-in-python-3504fdd282b5) (WHash)
@@ -118,8 +118,9 @@ plot_duplicates(image_dir='path/to/image/directory',
 duplicate_map=duplicates,
 filename='ukbench00120.jpg')
 ```
+It is also possible to use your own custom models for finding duplicates using the CNN method.
 
-For more examples, refer [this](https://github.com/idealo/imagededup/tree/master/examples) part of the
+For examples, refer [this](https://github.com/idealo/imagededup/tree/master/examples) part of the
 repository.
 
 For more detailed usage of the package functionality, refer: [https://idealo.github.io/imagededup/](https://idealo.github.io/imagededup/)
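The README addition above notes that user-defined models are also supported. As a preview of the pattern shown in the notebook below, any `torch.nn.Module` exposing a `name` string and a `transform` callable (PIL.Image in, `torch.Tensor` out) can be wrapped in `CustomModel`; the flatten-only `MyModel` here is purely illustrative.

```python
import torch
from torchvision.transforms import transforms

from imagededup.methods import CNN
from imagededup.utils import CustomModel

class MyModel(torch.nn.Module):
    # Preprocessing applied to each PIL.Image before the forward pass
    transform = transforms.Compose([transforms.Resize((256, 256)),
                                    transforms.ToTensor()])
    name = 'my_custom_model'  # any user-defined string

    def forward(self, x):
        # Replace with real feature extraction; here the image is simply flattened
        return x.view(-1, 256 * 256 * 3)  # output shape: batch_size x features

cnn = CNN(model_config=CustomModel(name=MyModel.name,
                                   model=MyModel(),
                                   transform=MyModel.transform))
duplicates = cnn.find_duplicates(image_dir='path/to/image/directory', scores=True)
```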
324 changes: 324 additions & 0 deletions examples/use_custom_model.ipynb
@@ -0,0 +1,324 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"image_dir = Path('../tests/data/mixed_images')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"# Use one of the prepackaged models"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from imagededup.methods import CNN\n",
"\n",
"# Get CustomModel construct\n",
"from imagededup.utils import CustomModel\n",
"\n",
"# Get efficientnet model\n",
"from imagededup.utils.models import EfficientNet\n",
"# Other models include ViT, MobilenetV3 (default selection)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Declare a custom config with CustomModel, the prepackaged models come with a name and transform function\n",
"custom_config = CustomModel(name=EfficientNet.name,\n",
" model=EfficientNet(),\n",
" transform=EfficientNet.transform)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"cnn_encoder = CNN(model_config=custom_config)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"duplicates_cnn = cnn_encoder.find_duplicates(image_dir=image_dir, scores=True)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"duplicates_cnn"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"# User-defined model"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from imagededup.methods import CNN\n",
"\n",
"# Get CustomModel construct\n",
"from imagededup.utils import CustomModel\n",
"\n",
"# Import necessary pytorch constructs for initializing a custom feature extractor\n",
"import torch\n",
"from torchvision.transforms import transforms\n",
"\n",
"# Declare custom feature extractor class\n",
"class MyModel(torch.nn.Module):\n",
" transform = transforms.Compose(\n",
" [\n",
" transforms.Resize((256, 256)),\n",
" transforms.ToTensor()\n",
" ]\n",
" ) # transform must take PIL.Image as input and return a torch.Tensor\n",
"\n",
" name = 'my_custom_model' # name can be any user-defined string\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
" # Define the layers of the model here\n",
"\n",
" def forward(self, x):\n",
" # Add more operations here\n",
" x = x.view(-1, 256*256*3) # output shape: batch_size x features\n",
" return x"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Initialize the CNN using model_config parameter and setting it to the custom model\n",
"custom_config = CustomModel(name=MyModel.name,\n",
" model=MyModel(),\n",
" transform=MyModel.transform)\n",
"\n",
"cnn = CNN(model_config=custom_config)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"duplicates_cnn = cnn.find_duplicates(image_dir=image_dir, scores=True)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"duplicates_cnn"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"# Using a huggingface model"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"!pip install transformers"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from imagededup.methods import CNN\n",
"\n",
"# Get CustomModel construct\n",
"from imagededup.utils import CustomModel\n",
"\n",
"# Import necessary constructs for initializing a huggingface transformers model\n",
"from transformers import ViTModel, AutoImageProcessor\n",
"import torch\n",
"from torchvision.transforms import transforms\n",
"\n",
"VIT_MODEL = \"google/vit-base-patch16-224-in21k\"\n",
"\n",
"def vit_transform(image):\n",
" transform = AutoImageProcessor.from_pretrained(VIT_MODEL)\n",
" x = transform(image, return_tensors = 'pt')['pixel_values']\n",
" return x\n",
"\n",
"class VitHgface(torch.nn.Module):\n",
" transform = transforms.Lambda(vit_transform)\n",
"\n",
" name = 'ViT_hgface'\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.vit = ViTModel.from_pretrained(VIT_MODEL)\n",
"\n",
" def forward(self, x):\n",
" x = x.view(-1, 3, 224, 224)\n",
" with torch.no_grad():\n",
" out = self.vit(pixel_values=x)\n",
" return out.pooler_output"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"custom_config = CustomModel(name=VitHgface.name,\n",
" model=VitHgface(),\n",
" transform=VitHgface.transform)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"cnn = CNN(model_config=custom_config)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"duplicates_cnn = cnn.find_duplicates(image_dir=image_dir, scores=True)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"duplicates_cnn"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"enc = cnn.encode_images(image_dir=image_dir)"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
