Documentation Improvements #745
base: main
`.gitignore`

```diff
@@ -31,7 +31,7 @@ pyrightconfig.json
 doc/_build/
 *.swp
 .DS_Store
+readme_misc.md
 
 # python
```
`README.md`

````diff
@@ -17,23 +17,20 @@
 </a>
 </p>
 
-OLMo is a repository for training and using AI2's state-of-the-art open language models.
-It is built by scientists, for scientists.
+OLMo is a repository for training and using AI2's state-of-the-art open language models. It is designed by scientists, for scientists.
 
 ## Installation
 
-First install [PyTorch](https://pytorch.org) according to the instructions specific to your operating system.
+First, install [PyTorch](https://pytorch.org) following the instructions specific to your operating system.
 
-To install from source (recommended for training/fine-tuning) run:
+For training and fine-tuning, we recommend installing from source:
 
 ```bash
 git clone https://github.com/allenai/OLMo.git
 cd OLMo
 pip install -e .[all]
 ```
 
-Otherwise you can install the model code by itself directly from PyPI with:
-
+You can also install from PyPI with:
 ```bash
 pip install ai2-olmo
 ```
````
```diff
@@ -58,7 +55,7 @@ The core models in the OLMo family released so far are (all trained on the [Dolm
 URLs to checkpoints at intermediate steps of the models' trainings can be found in the csv files under [`checkpoints/official/`](https://github.com/allenai/OLMo/blob/main/checkpoints/official). These 'directory' URLs cannot currently be directly accessed, but files within the directory are publicly accessible. These URLs can also be provided to the training script to resume training from the checkpoint (see [Training](#training)). Each checkpoint directory consists of:
 
 - `config.yaml`: the config at that training step.
-- `model.pt`, `optim.pt`, `train.pt`: model, optimizer and training state at that training step.
+- `model.safetensors`, `optim.safetensors`, `train.pt`: model, optimizer and training state at that training step.
 
 Details about the other types of OLMo checkpoints (including OLMo HF Transformers checkpoints) can be found in [Checkpoints.md](https://github.com/allenai/OLMo/blob/main/docs/Checkpoints.md).
```
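As a side note on the hunk above: the intermediate-checkpoint CSVs can be read with the standard library. A minimal sketch (the column names `Step` and `Checkpoint Directory` are taken from the download script added later in this PR; the CSV path is illustrative):

```python
import csv
from typing import Optional

def checkpoint_url_for_step(csv_path: str, step: str) -> Optional[str]:
    """Return the checkpoint directory URL recorded for a given training step."""
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            if row["Step"] == step:
                return row["Checkpoint Directory"]
    return None

# Example (path is illustrative):
print(checkpoint_url_for_step("checkpoints/official/OLMo-1B.csv", "2000"))
```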
````diff
@@ -87,8 +84,7 @@ print(olmo_pipe("Language modeling is"))
 ```
 
 ### Inference on finetuned checkpoints
 
-If you finetune the model using the code in [Fine-tuning](#fine-tuning), you can use the conversion script to convert a native OLMo checkpoint to a Hugging Face-compatible checkpoint.
+After fine-tuning the model using the code in the [Fine-tuning](#fine-tuning) section, you can use the conversion script to convert a native OLMo checkpoint to a HuggingFace-compatible format.
 
 ```bash
 python scripts/convert_olmo_to_hf_new.py --input_dir /path/to/olmo/checkpoint --output_dir /path/to/hf/checkpoint/ --tokenizer_json_path tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json
````
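For context, a minimal sketch of loading a checkpoint converted by the command above with Hugging Face `transformers` (the directory path is the placeholder from that command, not a real location):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "/path/to/hf/checkpoint/"  # placeholder output dir from the conversion command
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
olmo = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.float16)

inputs = tokenizer("Language modeling is", return_tensors="pt")
output = olmo.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0]))
```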
````diff
@@ -100,48 +96,47 @@ python scripts/convert_olmo_to_hf_new.py --input_dir /path/to/olmo/checkpoint --
 olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-7B-0724-hf", torch_dtype=torch.float16, load_in_8bit=True) # requires bitsandbytes
 ```
 
-The quantized model is more sensitive to typing / cuda, so it is recommended to pass the inputs as inputs.input_ids.to('cuda') to avoid potential issues.
+The quantized model is sensitive to input types and CUDA handling. To avoid potential issues, we recommend explicitly converting input IDs to CUDA using: `inputs.input_ids.to('cuda')`
 
-## Reproducibility
+## Training
 
-### Training
-
 The configs used to train the official OLMo models are provided in the [`configs/official/`](https://github.com/allenai/OLMo/blob/main/configs/official) directory.
 
-Note that while the training and validation data is public and free to download, the paths to the data within those configs are pointed at a CloudFlare R2 bucket, which requires an API key for programmatic access.
-So in order to use any of these configs to reproduce a training run you'll first have to download the corresponding data to a location of your choosing and then update the paths in the config accordingly.
-
-You can derive the public HTTP URL from an R2 URL by replacing `r2://olmo-data` with `https://olmo-data.org`.
-For example, if the R2 data URL is:
-
-`r2://olmo-data/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy`
-
-then the corresponding public URL is:
+Install required packages:
+```bash
+pip3 install ai2-olmo wandb datasets torchmetrics scikit-learn
+```
 
-`https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy`
+### Training from a Checkpoint
 
-Once you've updated the data paths in the config you can launch a training run via `torchrun`. For example, to launch the 1B model training on a single 8x GPU node, you would run:
+To continue training from a specific checkpoint:
 
+1. Download the checkpoint using the provided script. Checkpoints are listed in CSV files under `checkpoints/official/`:
 ```bash
-torchrun --nproc_per_node=8 scripts/train.py configs/official/OLMo-1B.yaml
+python scripts/download_checkpoints.py [PATH_TO_CSV] --save-dir [SAVE_PATH] --step [STEP]
 ```
 
-You can use the same method to launch multi-node jobs as well. See [the documentation](https://pytorch.org/docs/stable/elastic/run.html) for `torchrun` to understand the additional arguments you'll need to configure the rendezvous backend / endpoint.
+Example: To download checkpoint at step 2000:
+```bash
+python scripts/download_checkpoints.py checkpoints/official/OLMo-1B.csv --save-dir ./checkpoints/ --step 2000
+```
+**Note**: All checkpoints in `checkpoints/official/` are unsharded files.
````
**Review comment:** They aren't just files. Even in unsharded format, a checkpoint still consists of multiple files.
````diff
-To resume training from a checkpoint, you can pass its path (local or URL)
-to `scripts/train.py` with the `--load_path` arguments. For example, to resume training from step 1000 of the OLMo 1B run:
+2. Resume training using the downloaded checkpoint. You can specify either a local path or URL using the `--load_path` argument. For example, to resume training from step 2000 of the OLMo 1B run:
 
 ```bash
-torchrun --nproc_per_node=8 scripts/train.py configs/official/OLMo-1B.yaml --load_path=https://olmo-checkpoints.org/ai2-llm/olmo-small/w1r5xfzt/step1000-unsharded
+torchrun --nproc_per_node=8 scripts/train.py configs/official/OLMo-1B.yaml --load_path=checkpoints/step2000 --save_folder=./new_checkpoints --run_name=olmo_test --save_overwrite
````
**Review comment:** Without `--save_overwrite`, the program throws an error.
**Reply:** Only if the directory already exists.
````diff
 ```
+The command above:
+- Loads the checkpoint from `checkpoints/step2000`
+- Saves new checkpoints to `./new_checkpoints`
+- Names the training run `olmo_test` in wandb.
+- Overwrites existing checkpoints in the save folder.
 
 ### Inspecting training data
 
-You may be interested in inspecting the exact tokens that composed a particular batch during the training of one of the OLMo models.
-We provide tools to do this, but first you'll need to download the data as above (unless you have an R2 API key) and update the corresponding config accordingly.
-
-Then take note of the URL of the data order file you want, which can be found in the [Models Overview](#models-overview) table. For example, the data order file for the first epoch of the OLMo-7B model is [https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy](https://olmo-checkpoints.org/ai2-llm/olmo-small/46zc5fly/train_data/global_indices.npy).
+To inspect the exact tokens used in training batches for OLMo models, first download the training data. If you don't have an R2 API key, use the public HTTP URLs and update your configuration file with the local data paths. After completing this setup, you can use the inspection tools to examine the training batches.
````
**Review comment:** Nobody external would ever have an R2 key. I think we can skip that part of the instructions.
````diff
+Find the data order file URL in the [Models Overview](#models-overview) table. For example, the OLMo-7B model's first epoch data order file is located at [https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy](https://olmo-checkpoints.org/ai2-llm/olmo-small/46zc5fly/train_data/global_indices.npy).
 Once you have that you can use this snippet to inspect the data within a particular batch:
 
 ```python
````
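The Python snippet referenced above is truncated in this view. The core idea, as a rough sketch (the file path and batch size are assumptions; the real batch size comes from `global_train_batch_size` in the corresponding config):

```python
import numpy as np

# global_indices.npy is the data order file downloaded from the URL above.
global_indices = np.memmap("global_indices.npy", mode="r", dtype=np.uint32)

batch_size = 2048  # assumption: substitute the value from the training config
batch_idx = 100

# Indices into the pre-tokenized training dataset for the instances in this batch.
batch_instance_indices = global_indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
print(batch_instance_indices[:10])
```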
One file in the diff was deleted; its name and contents are not shown in this view.
`scripts/download_checkpoints.py` (new file):

```python
import csv
import os
import requests
from tqdm import tqdm
import argparse
from pathlib import Path
from urllib.parse import urljoin


def convert_to_r2_url(http_url):
    """Convert HTTP URL to R2 URL format."""
    if http_url.startswith('https://olmo-checkpoints.org/'):
        return http_url.replace('https://olmo-checkpoints.org/', 'r2://olmo-checkpoints/')
    return http_url


def convert_to_public_url(r2_url):
    """Convert R2 URL to public HTTP URL format."""
    if r2_url.startswith('r2://olmo-checkpoints/'):
        return r2_url.replace('r2://olmo-checkpoints/', 'https://olmo-checkpoints.org/')
    return r2_url


def download_file(url, save_path, chunk_size=8192):
    """Download a file with progress bar."""
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size = int(response.headers.get('content-length', 0))
    save_path.parent.mkdir(parents=True, exist_ok=True)

    with open(save_path, 'wb') as f:
        with tqdm(total=total_size, unit='B', unit_scale=True, desc=save_path.name) as pbar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))


def try_get_directory_listing(url):
    common_files = [
        "config.yaml",
        "model.pt",
        "optim.pt",
        "train.pt",
        "model.safetensors",
        "optim.safetensors",
    ]

    found_files = []
    for pattern in common_files:
        test_url = urljoin(url.rstrip('/') + '/', pattern)
        try:
            response = requests.head(test_url)
            if response.status_code == 200:
                found_files.append(pattern)
        except requests.exceptions.RequestException:
            continue
```
**Review comment** (on lines +52 to +53): Why would you swallow these exceptions?
**Follow-up:** Oh, because you don't expect all files to be there? Then at least catch only 404 errors.
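A sketch of what the reviewer is asking for, written as a hypothetical helper (not part of the PR): only a clean 404 is treated as "file not present", and every other failure propagates instead of being swallowed.

```python
import requests

def probe_file(base_url: str, name: str) -> bool:
    """Return True if `name` exists under `base_url`; only a 404 means 'missing'."""
    test_url = base_url.rstrip("/") + "/" + name
    response = requests.head(test_url)
    if response.status_code == 404:
        return False
    response.raise_for_status()  # raise on 403s, 500s, and other unexpected statuses
    return True
```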
```python
    return found_files


def download_checkpoint(url, save_dir):
    """Download all files from a checkpoint directory."""
    r2_url = convert_to_r2_url(url)
    public_url = convert_to_public_url(r2_url)

    base_path = Path(save_dir)
    base_path.mkdir(parents=True, exist_ok=True)

    print(f"\nR2 URL: {r2_url}")
    print(f"Public URL: {public_url}")
    print(f"Saving to: {base_path}")

    print("Checking for available files...")
    available_files = try_get_directory_listing(public_url)

    if not available_files:
        print("No files found using common patterns. The directory might be empty or use different file patterns.")
        return

    for file in available_files:
        file_url = urljoin(public_url.rstrip('/') + '/', file)
        file_path = base_path / file

        try:
            print(f"\nDownloading: {file}")
            download_file(file_url, file_path)
        except requests.exceptions.RequestException as e:
            print(f"Error downloading {file}: {e}")
            continue
```
**Review comment** (on lines +83 to +85): No, wait, don't just catch all exceptions. What about those that can be retried, like timeouts?
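One way to address this, sketched as a hypothetical wrapper around the script's `download_file` (the retry count and back-off values are arbitrary):

```python
import time
import requests

def download_with_retries(url, save_path, retries=3, backoff_seconds=2.0):
    """Retry transient failures (timeouts, dropped connections); re-raise anything else."""
    for attempt in range(1, retries + 1):
        try:
            return download_file(url, save_path)  # defined earlier in this script
        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError):
            if attempt == retries:
                raise
            time.sleep(backoff_seconds * attempt)
```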
```python
def main():
    parser = argparse.ArgumentParser(description='Download OLMo checkpoints from CSV')
    parser.add_argument('csv_file', type=str, help='Path to the CSV file containing checkpoint URLs')
    parser.add_argument('--save-dir', type=str, default='./checkpoints',
                        help='Base directory to save downloaded checkpoints')
    parser.add_argument('--step', type=str, help='Specific step number to download (optional)')
    parser.add_argument('--list-steps', action='store_true', help='List available step numbers and exit')
```
**Review comment:** If you have a tool that can perform multiple different actions, use subcommands.
```python
    args = parser.parse_args()

    print(f"Reading CSV file: {args.csv_file}")

    with open(args.csv_file, 'r') as f:
        reader = csv.DictReader(f)
        urls = [(row['Step'], row['Checkpoint Directory']) for row in reader]

    if args.list_steps:
        print("\nAvailable steps:")
        for step, _ in urls:
            print(f"Step {step}")
        return

    if args.step:
        urls = [(step, url) for step, url in urls if step == args.step]
        if not urls:
            print(f"Error: Step {args.step} not found in the CSV file.")
            print("Use --list-steps to see available step numbers.")
            return

    print(f"Saving checkpoints to: {args.save_dir}")
    print("\nURL conversions:")
    for step, url in urls:
        r2_url = convert_to_r2_url(url)
        public_url = convert_to_public_url(r2_url)
        print(f"\nStep {step}:")
        print(f"Original URL: {url}")
        print(f"R2 URL: {r2_url}")
        print(f"Public URL: {public_url}")

    proceed = input("\nDo you want to proceed with the download? (y/n): ")
    if proceed.lower() != 'y':
        print("Download cancelled.")
        return
```
**Review comment** (on lines +126 to +129): No, we don't ask for permission. The tools just do the thing. What if we'd want to script it? However, that means we have to make sure the tools never do anything dangerous by accident.
```python
    for step, url in urls:
        save_path = os.path.join(args.save_dir, f"step{step}")
        try:
            download_checkpoint(url, save_path)
        except Exception as e:
            print(f"Error during download of step {step}: {e}")
```
**Review comment** (on lines +131 to +136): Do you think anyone will ever want to download all steps? That's a lot of data. I think it's better if we give one command to list steps, and another to download one step, and let them deal with the rest.
```python
if __name__ == "__main__":
    main()
```
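Pulling the last few review comments together, a rough sketch of the subcommand-style interface being suggested (the command names and wiring are placeholders, not part of this PR):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Work with OLMo checkpoint CSVs")
    sub = parser.add_subparsers(dest="command", required=True)

    list_cmd = sub.add_parser("list-steps", help="List the step numbers in a checkpoint CSV")
    list_cmd.add_argument("csv_file", help="Path to the CSV file containing checkpoint URLs")

    dl_cmd = sub.add_parser("download", help="Download the checkpoint for a single step")
    dl_cmd.add_argument("csv_file", help="Path to the CSV file containing checkpoint URLs")
    dl_cmd.add_argument("step", help="Step number to download")
    dl_cmd.add_argument("--save-dir", default="./checkpoints", help="Directory to save files into")

    return parser

if __name__ == "__main__":
    args = build_parser().parse_args()
    # Dispatch to the listing / single-step download logic here; no interactive prompt.
    print(args)
```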
Another modified file:

```diff
@@ -268,9 +268,10 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
     )
     cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep
 elif cfg.distributed_strategy == DistributedStrategy.fsdp:
-    checkpoint_type = (
-        CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
-    )
+    # checkpoint_type = (
+    #     CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
+    # )
+    checkpoint_type = CheckpointType.unsharded
```
**Review comment** (on lines +271 to +274): What's this?
```diff
 else:
     raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!")
 
@@ -297,7 +298,9 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
     cfg.load_path,
     load_optimizer_state=not cfg.reset_optimizer_state,
     load_trainer_state=not cfg.reset_trainer_state,
-    sharded_checkpointer=cfg.load_path_sharded_checkpointer,
+    # sharded_checkpointer=cfg.load_path_sharded_checkpointer,
+    sharded_checkpointer= False,
+    checkpoint_type=CheckpointType.unsharded
```
**Review comment** (on lines -300 to +303): Same question here.
```diff
 )
 log.info("Checkpoint successfully loaded")
```
**Review comment** (on the new `readme_misc.md` entry in `.gitignore`): What is this?
**Reply:** This was my read_me with grammatical mistakes. I will remove it from `.gitignore`.