Optimize torchvision model #151

Closed · wants to merge 3 commits
148 changes: 126 additions & 22 deletions benchmarks/torchvision/main.py
@@ -41,12 +41,15 @@ def scaling(enable):
     yield
 
 
-def train_epoch(model, criterion, optimizer, loader, device, scaler=None):
-    model.train()
+def train_epoch(args, model, criterion, optimizer, loader, device, scaler=None):
+    transform = dict(device=device)
+    if "channel_last" in args.optim:
+        transform["memory_format"] = torch.channels_last
 
     for inp, target in voir.iterate("train", loader, True):
-        inp = inp.to(device)
+        inp = inp.to(**transform)
         target = target.to(device)
-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none="set_grad_none" in args.optim)
         with scaling(scaler is not None):
             output = model(inp)
             loss = criterion(output, target)
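
For the channel_last path to pay off, both the model weights (converted in model_optimizer below) and every input batch (the transform dict above) must use the torch.channels_last memory format, so the NHWC layout propagates through the convolutions. A minimal, self-contained sketch of the conversion, independent of this benchmark:

import torch

x = torch.randn(16, 3, 224, 224)
x = x.to(memory_format=torch.channels_last)  # same shape, NHWC strides
assert x.is_contiguous(memory_format=torch.channels_last)

conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(memory_format=torch.channels_last)
y = conv(x)  # the convolution output keeps the channels_last layout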
@@ -61,6 +64,105 @@ def train_epoch(model, criterion, optimizer, loader, device, scaler=None):
         optimizer.step()
 
 
+
+def model_optimizer(args, model, device):
+    model.train()
+
+    if "channel_last" in args.optim:
+        model = model.to(memory_format=torch.channels_last)
+
+    if "trace" in args.optim:
+        example = torch.randn((args.batch_size, 3, 224, 224)).to(device)
+        model = torch.jit.trace(model, example)
+        return model, model.parameters()
+
+    if "inductor" in args.optim:
+        from functorch import make_functional_with_buffers
+        from functorch.compile import make_boxed_func
+
+        model, params, buffers = make_functional_with_buffers(model)
+
+        model = make_boxed_func(model)
+
+        # alternative backends: "nvprims_nvfuser"
+        model = torch.compile(model, backend="inductor")
+
+        def forward(*inputs):
+            return model((params, buffers, *inputs))
+
+        return forward, params
+
+    return model, model.parameters()
+
+
+def dali(args, images_dir):
+    from nvidia.dali.pipeline import pipeline_def
+    import nvidia.dali.types as types
+    import nvidia.dali.fn as fn
+    from nvidia.dali.plugin.pytorch import DALIGenericIterator
+
+    @pipeline_def(num_threads=args.num_workers, device_id=0)
+    def get_dali_pipeline():
+        images, labels = fn.readers.file(
+            file_root=images_dir,
+            random_shuffle=True,
+            name="Reader",
+        )
+        # decode data on the GPU
+        images = fn.decoders.image_random_crop(
+            images,
+            device="mixed",
+            output_type=types.RGB,
+        )
+        # the rest of the processing happens on the GPU as well
+        images = fn.resize(images, resize_x=256, resize_y=256)
+        images = fn.crop_mirror_normalize(
+            images,
+            crop_h=224,
+            crop_w=224,
+            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
+            mirror=fn.random.coin_flip(),
+        )
+        return images, labels
+
+    train_data = DALIGenericIterator(
+        [get_dali_pipeline(batch_size=args.batch_size)],
+        ["data", "label"],
+        reader_name="Reader",
+    )
+
+    # note: this single generator spans every epoch; each `for` loop over it
+    # resumes where the previous one stopped
+    def iterate():
+        for _ in range(args.epochs):
+            for data in train_data:
+                x, y = data[0]["data"], data[0]["label"]
+                # DALI yields labels of shape (N, 1); squeeze to the (N,) int64
+                # targets that nn.CrossEntropyLoss expects, without leaving the GPU
+                yield x, torch.squeeze(y, dim=1).long()
+
+    yield from iterate()
+
+
+def dataloader(args, model, device):
+    if args.loader == "dali":
+        return dali(args, args.data)
+
+    if args.data:
+        train = datasets.ImageFolder(os.path.join(args.data, "train"), data_transforms)
+        return torch.utils.data.DataLoader(
+            train,
+            batch_size=args.batch_size,
+            shuffle=True,
+            num_workers=args.num_workers,
+        )
+    else:
+        return SyntheticData(
+            model=model,
+            device=device,
+            batch_size=args.batch_size,
+            n=1000,
+            fixed_batch=args.fixed_batch,
+        )
+
+
 class SyntheticData:
     def __init__(self, model, device, batch_size, n, fixed_batch):
         self.n = n
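
The inductor branch goes through functorch so that the compiled callable and its parameters can be returned separately. On PyTorch 2.x, torch.compile also accepts an nn.Module directly, which avoids the functional round-trip; a sketch with a stand-in model (an alternative, not what this patch does):

import torch

model = torch.nn.Linear(8, 2)  # stand-in for the torchvision model
compiled = torch.compile(model, backend="inductor")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # params are shared

loss = compiled(torch.randn(4, 8)).sum()  # first call triggers compilation
loss.backward()
optimizer.step()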
@@ -89,6 +191,21 @@ def main():
         metavar="N",
         help="input batch size for training (default: 16)",
     )
+    parser.add_argument(
+        "--loader",
+        type=str,
+        default="pytorch",
+        choices=["pytorch", "dali"],
+        help="Dataloader backend",
+    )
+    parser.add_argument(
+        "--optim",
+        type=str,
+        default=[],
+        nargs="+",
+        choices=["trace", "inductor", "script", "channel_last", "set_grad_none"],
+        help="Optimizations to enable",
+    )
     parser.add_argument(
         "--model", type=str, help="torchvision model name", required=True
     )
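
Since --optim is declared with nargs="+", several optimizations can be stacked behind one flag, and the membership tests ("channel_last" in args.optim, and so on) pick them out individually. A self-contained illustration of the parsing, using a stand-in parser with the same flag definition:

import argparse

parser = argparse.ArgumentParser()  # stand-in for the benchmark's parser
parser.add_argument(
    "--optim",
    type=str,
    default=[],
    nargs="+",
    choices=["trace", "inductor", "script", "channel_last", "set_grad_none"],
)
args = parser.parse_args(["--optim", "channel_last", "set_grad_none"])
assert args.optim == ["channel_last", "set_grad_none"]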
@@ -178,27 +295,14 @@ def main():

     model = getattr(tvmodels, args.model)()
     model.to(device)
 
+    model, params = model_optimizer(args, model, device)
+
     criterion = nn.CrossEntropyLoss().to(device)
 
-    optimizer = torch.optim.SGD(model.parameters(), args.lr)
+    optimizer = torch.optim.SGD(params, args.lr)
 
-    if args.data:
-        train = datasets.ImageFolder(os.path.join(args.data, "train"), data_transforms)
-        train_loader = torch.utils.data.DataLoader(
-            train,
-            batch_size=args.batch_size,
-            shuffle=True,
-            num_workers=args.num_workers,
-        )
-    else:
-        train_loader = SyntheticData(
-            model=model,
-            device=device,
-            batch_size=args.batch_size,
-            n=1000,
-            fixed_batch=args.fixed_batch,
-        )
+    train_loader = dataloader(args, model, device)
 
     if is_fp16_allowed(args):
         scaler = torch.cuda.amp.GradScaler()
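
The scaler created here is consumed inside train_epoch, whose scaler branch is elided from this hunk. For reference, the canonical torch.cuda.amp pattern looks like the sketch below; it reuses the names from train_epoch's scope and is not a quote of the actual body:

with torch.cuda.amp.autocast():
    output = model(inp)
    loss = criterion(output, target)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
scaler.update()                # adjusts the scale factor for the next iteration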
@@ -213,7 +317,7 @@
         if not args.no_stdout:
             print(f"Begin training epoch {epoch}/{args.epochs}")
         train_epoch(
-            model, criterion, optimizer, train_loader, device, scaler=scaler
+            args, model, criterion, optimizer, train_loader, device, scaler=scaler
         )
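
The set_grad_none option maps to optimizer.zero_grad(set_to_none=True), which drops the gradient tensors instead of writing zeros into them. A self-contained check of the difference:

import torch

p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([p], lr=0.1)

p.sum().backward()
opt.zero_grad(set_to_none=True)
assert p.grad is None  # gradients dropped, not zeroed

p.sum().backward()
opt.zero_grad(set_to_none=False)
assert torch.all(p.grad == 0)  # zeroed in place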

