RuntimeError: Can't start new thread #280
Hey @cgebbe, do you think you could provide an example with a dummy model and synthetic data? The model doesn't have to train at all. Also, you should try using the `StreamingDataLoader`; it handles a lot of things for correctness.

Best,
@tchaton: Thanks for the super quick reply and for pointing me to the … If not, I will try to provide a minimal reproducible example. The issue has a rather high priority on our side, so I might even try to implement a fix that limits the number of threads later.
Hey @cgebbe. Sounds great, let me know. If you can't create a reproducible script for me to investigate on my own, we can also do a pair-debugging session, if you are open to it.
Hey @cgebbe. Any updates?
I started using the …

Finally, there's a small example below which does NOT hang, but is otherwise pretty close to the original code. A pair-debugging session would be amazing, @tchaton, thanks a ton for the offer; I'm just waiting for the okay from my supervisor.

P.S.: I also successfully ran the …

```python
"""Script sketching out current training pipeline.

FILENAME=200_thread_problem.py
python $FILENAME fit --data=DataModule --model=TaskModule --trainer.fast_dev_run=100 --trainer.devices='[0,1,2]'
"""
import shutil
import litdata
import lightning as L
from lightning.pytorch.cli import LightningCLI
import torch
from torch.utils import data
import numpy as np
import PIL.Image
from pathlib import Path
import tqdm
from torch import nn
import logging

logger = logging.getLogger(__name__)

NUM_ITEMS_PER_DATASET = 49
NUM_DATASETS = 4
CACHE_DIR = Path("/scratch/dummy")
CACHE_DIR.mkdir(exist_ok=True, parents=True)


def _create_dataset(output_dir: Path):
    def _random_images(index):
        fake_image = torch.rand((3, 32, 32), dtype=torch.float32)
        dct = {"index": index, "image": fake_image}
        return dct

    litdata.optimize(
        fn=_random_images,
        inputs=list(range(NUM_ITEMS_PER_DATASET)),
        output_dir=str(output_dir),
        num_workers=4,
        chunk_bytes="64MB",
    )


def _get_dataset_dirpaths(prefix: str):
    for idx in tqdm.trange(NUM_DATASETS):
        yield CACHE_DIR / prefix / str(idx)


def create_datasets():
    for prefix in ["train", "val"]:
        for dirpath in _get_dataset_dirpaths(prefix):
            if dirpath.exists():
                shutil.rmtree(dirpath)
            _create_dataset(dirpath)


class CombinedDs(litdata.CombinedStreamingDataset):
    def __init__(self, prefix: str):
        lst = [
            litdata.StreamingDataset(input_dir=str(dirpath), drop_last=False)
            for dirpath in _get_dataset_dirpaths(prefix)
        ]
        length_per_dataset = [len(ds) for ds in lst]
        logger.info(f"{prefix=} has {length_per_dataset=}")
        super().__init__(lst)


class DataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.kwargs = dict(num_workers=20, pin_memory=True, batch_size=2)

    def train_dataloader(self):
        self.train_ds = CombinedDs("train")
        return self._get_dataloader(self.train_ds)

    def val_dataloader(self):
        self.val_ds = CombinedDs("val")
        return self._get_dataloader(self.val_ds)

    def _get_dataloader(self, ds):
        assert ds is not None
        dl = litdata.StreamingDataLoader(ds, **self.kwargs)
        # dl = data.DataLoader(ds, **self.kwargs)  # plain PyTorch alternative
        logger.info(f"{len(ds)=}, {len(dl)=}")
        return dl


class TaskModule(L.LightningModule):
    # from https://lightning.ai/docs/pytorch/stable/starter/introduction.html
    def __init__(self):
        super().__init__()
        self.model = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def training_step(self, batch, batch_idx):
        return self._calc_loss(batch)

    def validation_step(self, batch, batch_idx):
        return self._calc_loss(batch)

    def _calc_loss(self, batch):
        x = batch["image"]
        y = self.model(x)
        loss = torch.nn.functional.mse_loss(y, x)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def main():
    logging.basicConfig(level=logging.INFO)
    if 0:  # toggle: set to 1 to (re)create the synthetic datasets first
        create_datasets()
    else:
        LightningCLI()


# Guard so that worker processes spawned by litdata/Lightning don't re-run main().
if __name__ == "__main__":
    main()
```
Hey @cgebbe,
Hey @cgebbe. I made a new release. Could you try the latest release, 0.2.20? Use …
Thanks again for the quick answer, @tchaton. You recommend setting … I tried …

I'll try to reproduce the hang with the minimal example. Side questions: …
Hey @cgebbe. That's exactly what I meant. It would hang if you set `drop_last=False`. Right now, the solution is for you to add more training samples to the dataset.
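One plausible way to see why `drop_last=False` can hang under multi-GPU (DDP) training: if ranks end up with different numbers of batches, the rank with an extra batch blocks in a gradient all-reduce that the other ranks never join. The back-of-the-envelope sketch below uses the numbers from the script above (a 49-item dataset, `batch_size=2`, 3 devices) and assumes items are split as evenly as possible across ranks. That is a simplification: litdata actually assigns whole chunks to ranks and workers, so real counts can be even more uneven.

```python
import math


def batches_per_rank(num_items: int, world_size: int, batch_size: int, drop_last: bool) -> list:
    """Batches each rank would see under a simplified, even per-item split.

    Simplifying assumption (not litdata's actual logic): items are dealt
    out per item, and drop_last=True trims the dataset so every rank
    receives the same number of items.
    """
    base, remainder = divmod(num_items, world_size)
    if drop_last:
        items = [base] * world_size  # leftover items are dropped
    else:
        # The first `remainder` ranks each get one extra item.
        items = [base + (1 if rank < remainder else 0) for rank in range(world_size)]
    return [math.ceil(n / batch_size) for n in items]


if __name__ == "__main__":
    print(batches_per_rank(49, 3, 2, drop_last=False))  # uneven: rank 0 has an extra batch
    print(batches_per_rank(49, 3, 2, drop_last=True))   # even: all ranks stay in lockstep
```

With 49 items this gives 9, 8 and 8 batches without dropping, and 8 batches on every rank with dropping, which is consistent with adding samples (or dropping the remainder) to keep ranks in step.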
We could also explore padding with duplicated data, but this would increase litdata's complexity quite a lot.
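The padding idea is what PyTorch's `DistributedSampler` does when `drop_last=False`: it repeats samples from the start until the total divides evenly by the number of ranks, then shards the indices round-robin. A minimal index-level sketch of that strategy (without shuffling, and not litdata code); a real litdata version would also have to deal with chunk-based assignment, which is where the extra complexity comes from.

```python
import math


def pad_and_shard(num_items: int, world_size: int) -> list:
    """Pad indices by repeating from the start, then shard them by rank.

    Mirrors the padding strategy of torch's DistributedSampler(drop_last=False),
    minus shuffling. Every rank ends up with the same number of indices.
    """
    if num_items == 0:
        return [[] for _ in range(world_size)]
    indices = list(range(num_items))
    total_size = math.ceil(num_items / world_size) * world_size
    padding = total_size - num_items
    if padding:
        # Cycle through the dataset if padding exceeds its length.
        repeats = math.ceil(padding / num_items)
        indices += (indices * repeats)[:padding]
    # Rank r takes every world_size-th index starting at r.
    return [indices[rank:total_size:world_size] for rank in range(world_size)]


if __name__ == "__main__":
    shards = pad_and_shard(49, 3)
    print([len(s) for s in shards])  # equal-length shards; items 0 and 1 appear twice
```

The cost is that a few samples are seen twice per epoch, which slightly biases the epoch toward them.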
🐛 Bug
I got the following error after training for 2h or 11h:
Code sample
Unfortunately, I can't provide a minimal code sample, but the main points are:

- A `CombinedStreamingDataset` with ~7000 small `StreamingDataset`s. The reason is that we need to specify several subsets of the 7000 datasets, and we do it this way (happy to learn about alternatives). While I know this is not optimal, it seemed to work fine at first (and also maxed out GPU utilization).
- `torch.utils.data.DataLoader`
- `lightning.pytorch.cli.LightningCLI`
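`RuntimeError: can't start new thread` means the OS refused to create another thread, typically because a per-user or system thread limit was reached. A small stdlib-only diagnostic (not part of the original setup) that could be logged periodically during a long run to check whether the thread count keeps growing. The `/proc` read is Linux-only, and since DataLoader workers are separate processes, the report only covers the process that calls it.

```python
import os
import threading

try:
    import resource  # POSIX only
except ImportError:  # e.g. Windows
    resource = None


def thread_report() -> dict:
    """Snapshot of thread usage and limits for the calling process."""
    report = {
        "python_threads": threading.active_count(),
        "os_threads": None,
        "nproc_soft": None,
        "nproc_hard": None,
    }
    # Kernel view of this process's threads, including native threads
    # that Python's threading module does not track. Linux-only.
    try:
        with open(f"/proc/{os.getpid()}/status") as fh:
            for line in fh:
                if line.startswith("Threads:"):
                    report["os_threads"] = int(line.split()[1])
                    break
    except OSError:
        pass
    # On Linux, threads count toward the per-user RLIMIT_NPROC limit.
    if resource is not None and hasattr(resource, "RLIMIT_NPROC"):
        soft, hard = resource.getrlimit(resource.RLIMIT_NPROC)
        report["nproc_soft"], report["nproc_hard"] = soft, hard
    return report


if __name__ == "__main__":
    print(thread_report())
```

A steadily rising `os_threads` value across epochs would point toward threads being created but never cleaned up.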
Environment

- Installed with (`conda`, `pip`, source): uv pip

Additional Info