Script for ranking dataset using trained Discriminator #6

Open · gwern opened this issue Dec 4, 2018 · 0 comments

gwern commented Dec 4, 2018

One useful trick for data cleaning is to take a trained Discriminator and use it to find the 'worst' samples, which can then be reviewed manually or deleted automatically.

I modified the training code to support ranking a folder dataset with a trained D model. The default dataloader doesn't preserve image paths, so an additional dataset class has to be added:

diff --git a/data_loader.py b/data_loader.py
index 736362c..c4528a8 100755
--- a/data_loader.py
+++ b/data_loader.py
@@ -42,9 +42,15 @@ class Data_Loader():
 
     def load_off(self):
         transforms = self.transform(True, True, True, False)
         dataset = dsets.ImageFolder(self.path, transform=transforms)
         return dataset
 
+    def load_rank(self):
+        transforms = self.transform(True, True, True, False)
+        dataset = ImageFolderWithPaths(self.path, transform=transforms)
+        return dataset
+
     def loader(self):
         if self.dataset == 'lsun':
             dataset = self.load_lsun()
@@ -54,6 +60,8 @@ class Data_Loader():
             dataset = self.load_celeb()
         elif self.dataset == 'off':
             dataset = self.load_off()
+        elif self.dataset == 'rank':
+            dataset = self.load_rank()
 
         print('dataset',len(dataset))
         loader = torch.utils.data.DataLoader(dataset=dataset,
@@ -63,3 +71,18 @@ class Data_Loader():
                                               drop_last=True)
         return loader
 
+
+class ImageFolderWithPaths(dsets.ImageFolder):
+    """Custom dataset that includes image file paths. Extends
+    torchvision.datasets.ImageFolder.
+    """
+
+    # override the __getitem__ method; this is the method the DataLoader calls
+    def __getitem__(self, index):
+        # this is what ImageFolder normally returns: (image, label)
+        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
+        # the image file path
+        path = self.imgs[index][0]
+        # make a new tuple that includes the original plus the path
+        tuple_with_path = original_tuple + (path,)
+        return tuple_with_path
diff --git a/parameter.py b/parameter.py
index 0b59c5a..1191208 100755
--- a/parameter.py
+++ b/parameter.py
@@ -37,7 +37,7 @@ def get_parameters():
     parser.add_argument('--train', type=str2bool, default=True)
     parser.add_argument('--parallel', type=str2bool, default=False)
     parser.add_argument('--gpus', type=str, default='0', help='gpuids eg: 0,1,2,3  --parallel True  ')
-    parser.add_argument('--dataset', type=str, default='lsun', choices=['lsun', 'celeb','off'])
+    parser.add_argument('--dataset', type=str, default='lsun', choices=['lsun', 'celeb','off', 'rank'])
     parser.add_argument('--use_tensorboard', type=str2bool, default=False)
 
     # Path
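
With these changes, the 'rank' dataset yields a third element per batch: the file path. A minimal sketch of consuming it directly (the image root below is a hypothetical placeholder; the argument order follows the Data_Loader call in main() further down):

from data_loader import Data_Loader

# train=True, dataset='rank', hypothetical image root, imsize, batch_size, shuf
loader = Data_Loader(True, 'rank', '/data/faces', 128, 4, shuf=False).loader()

for images, labels, paths in loader:
    # images: [4, 3, 128, 128] tensor; labels: [4] tensor of class indices;
    # paths: tuple of 4 file-path strings added by ImageFolderWithPaths
    print(paths[0], tuple(images.shape), labels[0].item())
    break

Note that the repo's DataLoader is built with drop_last=True, so a trailing partial batch is silently skipped unless batch_size divides the dataset size (one more reason --batch_size 1 is convenient here).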

Then a ranker.py script can be based on train.py:

import torch
import torch.nn as nn
from model_resnet import Discriminator
# from utils import *
from parameter import *
from data_loader import Data_Loader
from torch.backends import cudnn
import os

# Based on train.py's Trainer, but only runs the Discriminator forward pass
class Ranker(object):
    def __init__(self, data_loader, config):

        # Data loaders
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model

        # Model hyper-parameters
        self.imsize = config.imsize
        self.parallel = config.parallel
        self.gpus = config.gpus

        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.image_path = config.image_path
        self.version = config.version

        self.n_class = 1000 # config.n_class TODO
        self.chn = config.chn

        # Path
        self.model_save_path = os.path.join(config.model_save_path, self.version)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.build_model()

        # Start from the trained model and immediately rank
        self.load_pretrained_model()
        self.rank()

    def rank(self):

        self.D.eval()  # eval mode: scoring only, no dropout/batch-norm updates
        with torch.no_grad():  # no gradients needed for a forward pass
            for real_images, real_labels, real_paths in self.data_loader:
                real_labels = real_labels.to(self.device)
                real_images = real_images.to(self.device)
                # higher D output = the Discriminator considers it more 'real'
                d_out_real = self.D(real_images, real_labels)

                rankings = d_out_real.tolist()
                for path, score in zip(real_paths, rankings):
                    print(path, score)


    def load_pretrained_model(self):
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model)),
            map_location=self.device))

    def build_model(self):
        # code_dim=100, n_class=1000
        self.D = Discriminator(self.n_class, chn=self.chn).to(self.device)
        if self.parallel:
            gpus = [int(i) for i in self.gpus.split(',')]
            self.D = nn.DataParallel(self.D, device_ids=gpus)

def main(config):
    # For fast training
    cudnn.benchmark = True

    # Data loader
    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                              config.batch_size, shuf=False)
    Ranker(data_loader.loader(), config)

if __name__ == '__main__':
    config = get_parameters()
    # print(config)
    main(config)

Example use:

$ python ranker.py --batch_size 1  --dataset rank --image_path /media/gwern/Data/danbooru2017/characters-faces/ --version 1kfaces --parallel True --gpus 0,1 --pretrained 627003
dataset 750356
/media/gwern/Data/danbooru2017/characters-faces/2k-tan/10139.jpg0.png 3.864508628845215
/media/gwern/Data/danbooru2017/characters-faces/2k-tan/104024.jpg0.png 3.384716272354126
/media/gwern/Data/danbooru2017/characters-faces/2k-tan/104024.jpg1.png 3.004866600036621
/media/gwern/Data/danbooru2017/characters-faces/2k-tan/1054281.jpg0.png 3.4808170795440674
/media/gwern/Data/danbooru2017/characters-faces/2k-tan/108431.jpg0.png 3.6894052028656006
/media/gwern/Data/danbooru2017/characters-faces/2k-tan/108805.jpg0.png 3.898812770843506
/media/gwern/Data/danbooru2017/characters-faces/2k-tan/111774.jpg0.png 2.8409836292266846
/media/gwern/Data/danbooru2017/characters-faces/2k-tan/114465.jpg0.png 3.883681058883667
/media/gwern/Data/danbooru2017/characters-faces/2k-tan/115171.jpg0.png 4.3082122802734375
...
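
Since a lower D score means the Discriminator found the image less convincing, the 'worst' samples are simply the lowest-scored lines. A small sketch for pulling out the bottom N once the stdout above has been redirected to a file (the filename rankings.txt and N are hypothetical choices):

# Assumes: python ranker.py ... > rankings.txt
N = 100
rows = []
with open('rankings.txt') as f:
    for line in f:
        if not line.startswith('/'):
            continue  # skip the 'dataset 750356' header line
        path, score = line.rsplit(maxsplit=1)
        rows.append((float(score), path))

rows.sort()  # ascending: lowest-scored ('worst') samples first
for score, path in rows[:N]:
    print(score, path)  # review these by hand, or os.remove(path) once confident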

Some cleaner built-in support would be good.
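
One possible shape for that (the flag and names here are my assumptions, not the repo's API): a --rank flag in parameter.py that routes main() to the D-only scoring path instead of training:

# parameter.py (hypothetical): parser.add_argument('--rank', type=str2bool, default=False)
def main(config):
    cudnn.benchmark = True
    data_loader = Data_Loader(config.train, config.dataset, config.image_path,
                              config.imsize, config.batch_size, shuf=False)
    if config.rank:
        Ranker(data_loader.loader(), config)   # forward-pass scoring, as above
    else:
        Trainer(data_loader.loader(), config)  # existing training entry point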
