One useful trick for data cleaning is to take a trained Discriminator, use it to find the 'worst' samples, and then either review them manually or delete them automatically.
I modified the training script to support this by ranking a folder dataset with a trained D model. The default dataloader doesn't preserve image paths, so an additional dataset class has to be added:

```diff
diff --git a/data_loader.py b/data_loader.py
index 736362c..c4528a8 100755
--- a/data_loader.py
+++ b/data_loader.py
@@ -29,7 +29,7 @@ class Data_Loader():
         transforms = self.transform(True, True, True, False)
         dataset = dsets.LSUN(self.path, classes=classes, transform=transforms)
         return dataset
-
+
     def load_imagenet(self):
         transforms = self.transform(True, True, True, True)
         dataset = dsets.ImageFolder(self.path+'/imagenet', transform=transforms)
@@ -42,9 +42,15 @@ class Data_Loader():
     def load_off(self):
         transforms = self.transform(True, True, True, False)
         dataset = dsets.ImageFolder(self.path, transform=transforms)
         return dataset
 
+    def load_rank(self):
+        transforms = self.transform(True, True, True, False)
+        dataset = ImageFolderWithPaths(self.path, transform=transforms)
+        return dataset
+
     def loader(self):
         if self.dataset == 'lsun':
             dataset = self.load_lsun()
@@ -54,6 +60,8 @@ class Data_Loader():
             dataset = self.load_celeb()
         elif self.dataset == 'off':
             dataset = self.load_off()
+        elif self.dataset == 'rank':
+            dataset = self.load_rank()
 
         print('dataset',len(dataset))
         loader = torch.utils.data.DataLoader(dataset=dataset,
@@ -63,3 +71,18 @@ class Data_Loader():
                                              drop_last=True)
         return loader
+
+class ImageFolderWithPaths(dsets.ImageFolder):
+    """Custom dataset that includes image file paths. Extends
+    torchvision.datasets.ImageFolder
+    """
+
+    # override the __getitem__ method. this is the method dataloader calls
+    def __getitem__(self, index):
+        # this is what ImageFolder normally returns
+        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
+        # the image file path
+        path = self.imgs[index][0]
+        # make a new tuple that includes original and the path
+        tuple_with_path = (original_tuple + (path,))
+        return tuple_with_path
diff --git a/parameter.py b/parameter.py
index 0b59c5a..1191208 100755
--- a/parameter.py
+++ b/parameter.py
@@ -37,7 +37,7 @@ def get_parameters():
     parser.add_argument('--train', type=str2bool, default=True)
     parser.add_argument('--parallel', type=str2bool, default=False)
     parser.add_argument('--gpus', type=str, default='0', help='gpuids eg: 0,1,2,3 --parallel True ')
-    parser.add_argument('--dataset', type=str, default='lsun', choices=['lsun', 'celeb','off'])
+    parser.add_argument('--dataset', type=str, default='lsun', choices=['lsun', 'celeb','off', 'rank'])
     parser.add_argument('--use_tensorboard', type=str2bool, default=False)
 
     # Path
```
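As a quick sanity check that the new 'rank' dataset really yields (image, label, path) batches, something like the following should work; the ./data folder, image size, and batch size here are placeholders of mine, not values from the repo:

```python
# Hypothetical smoke test for the 'rank' dataset (paths survive the default collate).
from data_loader import Data_Loader

# Positional args follow Data_Loader(train, dataset, image_path, imsize, batch_size, shuf=...)
data_loader = Data_Loader(True, 'rank', './data', 128, 4, shuf=False)
images, labels, paths = next(iter(data_loader.loader()))
print(images.shape, labels.shape)
print(paths[:2])  # list of file path strings, one per image in the batch
```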
Then a ranker.py script can be based on train.py:
```python
import torch
import torch.nn as nn
from model_resnet import Discriminator
# from utils import *
from parameter import *
from data_loader import Data_Loader
from torch.backends import cudnn
import os


class Trainer(object):
    def __init__(self, data_loader, config):

        # Data loaders
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model

        # Model hyper-parameters
        self.imsize = config.imsize
        self.parallel = config.parallel
        self.gpus = config.gpus

        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.pretrained_model = config.pretrained_model
        self.dataset = config.dataset
        self.image_path = config.image_path
        self.version = config.version
        self.n_class = 1000  # config.n_class TODO
        self.chn = config.chn

        # Path
        self.model_save_path = os.path.join(config.model_save_path, self.version)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.build_model()

        # Start with trained model
        self.load_pretrained_model()

        self.train()

    def train(self):
        self.D.train()

        # Data iterator
        data_iter = iter(self.data_loader)
        total_steps = self.data_loader.__len__()

        for step in range(0, total_steps):
            real_images, real_labels, real_paths = next(data_iter)
            real_labels = real_labels.to(self.device)
            real_images = real_images.to(self.device)

            d_out_real = self.D(real_images, real_labels)
            rankings = d_out_real.data.tolist()
            for i in range(0, len(real_paths)):
                print(real_paths[i], rankings[i])

    def load_pretrained_model(self):
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))

    def build_model(self):
        # code_dim=100, n_class=1000
        self.D = Discriminator(self.n_class, chn=self.chn).to(self.device)
        if self.parallel:
            gpus = [int(i) for i in self.gpus.split(',')]
            self.D = nn.DataParallel(self.D, device_ids=gpus)


def main(config):
    # For fast training
    cudnn.benchmark = True

    # Data loader
    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                              config.batch_size, shuf=False)

    Trainer(data_loader.loader(), config)


if __name__ == '__main__':
    config = get_parameters()
    # print(config)
    main(config)
```
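To act on the scores rather than just read the printed list, the loop at the end of train() could collect (path, score) pairs and move the lowest-scoring files aside for manual review. A minimal sketch of that variation; the quarantine directory and cutoff are placeholders of mine, not part of the script above:

```python
import os
import shutil

def quarantine_worst(ranked, n_worst=100, out_dir='./quarantine'):
    """Move the n_worst lowest-scoring images into out_dir for manual review.

    `ranked` is a list of (path, score) pairs, e.g. collected inside
    Trainer.train() instead of printing. Lower D outputs mark the samples
    the Discriminator finds least realistic.
    """
    os.makedirs(out_dir, exist_ok=True)
    for path, score in sorted(ranked, key=lambda x: x[1])[:n_worst]:
        shutil.move(path, os.path.join(out_dir, os.path.basename(path)))
        print('moved', path, score)
```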
Example use (an illustrative invocation; the exact flags and checkpoint name depend on how the D model was trained and saved):
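```
python ranker.py --dataset rank --image_path ./data --version sagan_celeb \
    --pretrained_model 100000 --batch_size 64 --imsize 128 > rankings.txt

# lowest D scores first = the samples the Discriminator finds least realistic
sort -g -k2 rankings.txt | head -n 50
```

The output is one `path score` pair per line, so it can be sorted and reviewed by hand or fed to a deletion/quarantine script like the sketch above.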
Some cleaner built-in support would be good.