Skip to content

Commit

Permalink
Tweak rate computation to not include iterator creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed May 13, 2024
1 parent 1fbea0a commit 028b228
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 12 deletions.
7 changes: 7 additions & 0 deletions benchmarks/torchvision_ddp/activator
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash
# Activate the virtualenv given as the first argument, then replace this
# shell with the remaining command line (so signals/exit codes pass through).

venv_path="$1"
shift

source "$venv_path"/bin/activate
exec "$@"
50 changes: 38 additions & 12 deletions benchmarks/torchvision_ddp/main.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python

import argparse
import os
import sys
Expand Down Expand Up @@ -53,10 +55,10 @@ def __init__(
) -> None:
self.rank = gpu_id
self.device = accelerator.fetch_device(gpu_id)
self.model = model.to(gpu_id)
self.model = model.to(self.device)
self.train_data = train_data
self.optimizer = optimizer
# self.model = FSDP(model, device_id=gpu_id)
# self.model = FSDP(model, device_id=self.device)
self.model = DDP(model, device_ids=[self.device])
self.world_size = world_size
self.data_file = SmuggleWriter(sys.stdout)
Expand All @@ -67,6 +69,7 @@ def print(self, *args, **kwargs):

def _run_batch(self, source, targets):
with accelerator.amp.autocast(dtype=torch.bfloat16):

self.optimizer.zero_grad()
output = self.model(source)
loss = F.cross_entropy(output, targets)
Expand All @@ -86,25 +89,45 @@ def toiterator(loader):

sample_count = 0
losses = []
events = []

self.train_data.sampler.set_epoch(epoch)
loader = timeiterator(voir.iterate("train", toiterator(self.train_data), True))

start_event = accelerator.Event(enable_timing=True)
start_event.record()

for source, targets in loader:
end_event = accelerator.Event(enable_timing=True)

with timeit("batch"):
source = source.to(self.device)
targets = targets.to(self.device)

sample_count += len(source)
n = len(source)
sample_count += n

loss = self._run_batch(source, targets)
losses.append(loss)

end_event.record()
events.append((start_event, end_event, n))
start_event = end_event

for start, end, n in events:
end.synchronize()
elapsed = start.elapsed_time(end) / 1000
rate = (n * self.world_size) / elapsed
self.log({
"task": "train",
"rate": rate,
"units": "items/s",
})

total_count = torch.tensor([sample_count], dtype=torch.int64, device=self.device)
dist.reduce(total_count, dst=0)
accelerator.synchronize()
loss = sum([l.item() for l in losses]) / len(losses)

loss = sum([l.item() for l in losses]) / len(losses)
return total_count.item(), loss

def train(self, max_epochs: int):
Expand All @@ -124,11 +147,11 @@ def log(self, data):
def perf(self, loss, total_count, timer):
if self.rank == 0:
self.log({"task": "train", "loss": loss})
self.log({
"task": "train",
"rate": total_count / (timer.end - timer.start),
"units": "items/s",
})
# self.log({
# "task": "train",
# "rate": total_count / (timer.end - timer.start),
# "units": "items/s",
# })


def image_transforms():
Expand All @@ -147,7 +170,7 @@ def prepare_dataloader(dataset: Dataset, args):
return DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers if args.noio else 0,
num_workers=args.num_workers if not args.noio else 0,
pin_memory=not args.noio,
shuffle=False,
sampler=DistributedSampler(dataset)
Expand Down Expand Up @@ -207,11 +230,14 @@ def worker_main(rank: int, world_size: int, args):
print(f"<<< rank: {rank}")
except Exception as err:
print(err)
finally:
if rank == 0:
show_timings(True)


def main():
parser = argparse.ArgumentParser(description='simple distributed training job')
parser.add_argument('--batch_size', default=512, type=int, help='Input batch size on each device (default: 32)')
parser.add_argument('--batch-size', default=512, type=int, help='Input batch size on each device (default: 512)')
parser.add_argument(
"--model", type=str, help="torchvision model name", default="resnet50"
)
Expand Down
26 changes: 26 additions & 0 deletions config/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ _torchvision:
--epochs: 50
--num-workers: 8


_torchvision_ddp:
inherits: _defaults
definition: ../benchmarks/torchvision_ddp
group: torchvision
install_group: torch
plan:
method: njobs
n: 1
argv:
--epochs: 10
--num-workers: 8

_flops:
inherits: _defaults
definition: ../benchmarks/flops
Expand Down Expand Up @@ -193,6 +206,19 @@ resnet50-noio:
--synthetic-data: true


resnet152-ddp:
inherits: _torchvision_ddp
tags:
- vision
- classification
- convnet
- resnet

argv:
--model: resnet152
--batch-size: 256
--num-workers: 8

efficientnet_b4:
inherits: _torchvision

Expand Down

0 comments on commit 028b228

Please sign in to comment.