Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend functionality of Wandb Config Diff script #687

Merged
merged 12 commits into from
Oct 30, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added ability to try loading latest checkpoint from save folder using `--try_load_latest_save`.
- Added support for flash attention and gradient checkpointing to `hf_olmo`.
- Added to `scripts.compare_wandb_configs.py` the ability to more easily compare differences in data mixes and evaluation tasks.
- Added `effective_n_kv_heads` to OLMoConfig for hacky VLLM support.

## [v0.5.0](https://github.com/allenai/OLMo/releases/tag/v0.5.0) - 2024-08-26
Expand Down
20 changes: 18 additions & 2 deletions olmo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,12 +841,28 @@ def get_bytes_range(self, index: int, length: int) -> bytes:
return response["Body"].read()


def flatten_dict(dictionary, parent_key="", separator="."):
def flatten_dict(dictionary, parent_key="", separator=".", include_lists=False):
"""
Flatten a nested dictionary into a single-level dictionary.

Args:
dictionary (dict): The nested dictionary to be flattened.
parent_key (str, optional): The parent key to be prepended to the keys of the flattened dictionary. Defaults to "".
separator (str, optional): The separator to be used between the parent key and the keys of the flattened dictionary. Defaults to ".".
include_lists (bool, optional): Whether to convert lists to dictionaries with integer keys. Defaults to False.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we ever want to turn this off?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not for now, but seems fine to extend this in case someone wants to add different logic for dealing w list-valued config params?


Returns:
dict: The flattened dictionary.

"""
d: Dict[str, Any] = {}
for key, value in dictionary.items():
new_key = parent_key + separator + key if parent_key else key
# convert lists to dict with key <int>
if isinstance(value, list) and include_lists:
value = {f"{i}": v for i, v in enumerate(value)}
if isinstance(value, MutableMapping):
d.update(**flatten_dict(value, new_key, separator=separator))
d.update(**flatten_dict(value, new_key, separator=separator, include_lists=include_lists))
else:
d[new_key] = value
return d
135 changes: 111 additions & 24 deletions scripts/compare_wandb_configs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
"""

Examples:
Comparing Peteish7 to OLMoE
- python scripts/compare_wandb_configs.py https://wandb.ai/ai2-llm/olmo-medium/runs/cej4ya39 https://wandb.ai/ai2-llm/olmoe/runs/rzsn9tlc

Comparing Peteish7 to Amberish7
- python scripts/compare_wandb_configs.py https://wandb.ai/ai2-llm/olmo-medium/runs/cej4ya39 https://wandb.ai/ai2-llm/olmo-medium/runs/ij4ls6v2


"""

import logging
import os
import re
from collections import Counter

import click

Expand All @@ -24,6 +38,47 @@ def parse_run_path(run_path: str) -> str:
raise ValueError(f"Could not parse '{run_path}'")


def print_keys_with_differences(left_config, right_config):
s_left = ""
left_only_keys = left_config.keys() - right_config.keys()
if len(left_only_keys) > 0:
s_left += "Settings only in left:\n"
s_left += "\n".join(f"\t{k}: {left_config[k]}" for k in sorted(left_only_keys)) + "\n"

s_right = ""
right_only_keys = right_config.keys() - left_config.keys()
if len(right_only_keys) > 0:
s_right += "Settings only in right:\n"
s_right += "\n".join(f"\t{k}: {right_config[k]}" for k in sorted(right_only_keys)) + "\n"

s_shared = ""
keys_with_differences = {
k for k in left_config.keys() & right_config.keys() if left_config[k] != right_config[k]
}
if len(keys_with_differences) > 0:
for k in sorted(keys_with_differences):
s_shared += f"{k}\n\t{left_config[k]}\n" + f"\t{right_config[k]}\n\n"

if (s_left or s_right) and not s_shared:
s = s_left + "=" * 50 + "\n" + s_right + "=" * 50 + "\n" + "No differences in shared settings.\n"
else:
s = s_left + "=" * 50 + "\n" + s_right + "=" * 50 + "\n" + s_shared
print(s.strip())


def print_data_differences(left_data_paths: Counter, right_data_paths: Counter):
print("===== Data Paths for left config:\n")
simplified_left_data_paths = {path: count for path, count in left_data_paths.items()}
for path, num_files in simplified_left_data_paths.items():
print(f"\t{path}: {num_files}")
print("\n\n")

print("===== Data Paths for right config:\n")
simplified_right_data_paths = {path: count for path, count in right_data_paths.items()}
for path, num_files in simplified_right_data_paths.items():
print(f"\t{path}: {num_files}")


@click.command()
@click.argument(
"left_run_path",
Expand All @@ -43,30 +98,62 @@ def main(
left_run = api.run(parse_run_path(left_run_path))
right_run = api.run(parse_run_path(right_run_path))

left_config = flatten_dict(left_run._attrs["rawconfig"])
right_config = flatten_dict(right_run._attrs["rawconfig"])

left_only_keys = left_config.keys() - right_config.keys()
if len(left_only_keys) > 0:
print("Settings only in left:")
print("\n".join(f"\t{k}: {left_config[k]}" for k in sorted(left_only_keys)))
print()

right_only_keys = right_config.keys() - left_config.keys()
if len(right_only_keys) > 0:
print("Settings only in right:")
print("\n".join(f"\t{k}: {right_config[k]}" for k in sorted(right_only_keys)))
print()

keys_with_differences = {
k for k in left_config.keys() & right_config.keys() if left_config[k] != right_config[k]
}
if len(keys_with_differences) > 0:
if len(left_only_keys) > 0 or len(right_only_keys) > 0:
print("Settings with differences:")
print("\n".join(f"{k}\n\t{left_config[k]}\n\t{right_config[k]}\n" for k in sorted(keys_with_differences)))
else:
print("No differences in shared settings.")
left_config_raw = left_run._attrs["rawconfig"]
right_config_raw = right_run._attrs["rawconfig"]

# flattening the dict will make diffs easier
left_config = flatten_dict(left_config_raw)
right_config = flatten_dict(right_config_raw)

# there are 2 specific fields in config that are difficult to diff:
# "evaluators" is List[Dict]
# "data.paths" is List[str]
# let's handle each of these directly.

# first, data.paths can be grouped and counted.
left_data_paths = Counter([os.path.dirname(path) for path in left_config["data.paths"]])
right_data_paths = Counter([os.path.dirname(path) for path in right_config["data.paths"]])
del left_config["data.paths"]
del right_config["data.paths"]

# next, evaluators can be added to the flat dict with unique key per evaluator
# also, each evaluator can also have a 'data.paths' field which needs collapsing
def _simplify_evaluator(evaluator):
evaluator = flatten_dict(evaluator)
if evaluator["data.paths"]:
evaluator["data.paths"] = Counter([os.path.dirname(path) for path in evaluator["data.paths"]])
return evaluator

def _simplify_evaluators(evaluators):
simplified_evaluators = {}
for evaluator in evaluators:
new_key = (".".join(["evaluators" + "." + evaluator["type"] + "." + evaluator["label"]])).upper()
simplified_evaluators[new_key] = _simplify_evaluator(evaluator)
return simplified_evaluators

left_evaluators = flatten_dict(_simplify_evaluators(left_config["evaluators"]), separator="___")
right_evaluators = flatten_dict(_simplify_evaluators(right_config["evaluators"]), separator="___")
del left_config["evaluators"]
del right_config["evaluators"]

print(
f"==================== Config differences between {left_run_path} and {right_run_path} ====================\n\n"
)

# print config differences
print("==================== Param differences ====================\n\n")
print_keys_with_differences(left_config=left_config, right_config=right_config)
print("============================================================= \n\n")

# print data differences
print("==================== Data Differences ====================\n\n")
print_data_differences(left_data_paths, right_data_paths)
print("============================================================= \n\n")

# print eval differences
print("==================== Eval Differences ====================\n\n")
print_keys_with_differences(left_config=left_evaluators, right_config=right_evaluators)
print("============================================================= \n\n")


if __name__ == "__main__":
Expand Down
28 changes: 28 additions & 0 deletions tests/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,31 @@ def test_dir_is_empty(tmp_path):
# Should return false if dir contains anything, even hidden files.
(dir / ".foo").touch()
assert not util.dir_is_empty(dir)


def test_flatten_dict():
# basic flattening
test_dict = {"a": 0, "b": {"e": 5, "f": 1}, "c": 2}
assert util.flatten_dict(test_dict) == {"a": 0, "b.e": 5, "b.f": 1, "c": 2}

# Should flatten nested dicts into a single dict with dotted keys.
test_dict_with_list_of_dicts = {
"a": 0,
"b": {"e": [{"x": {"z": [222, 333]}}, {"y": {"g": [99, 100]}}], "f": 1},
"c": 2,
}
assert util.flatten_dict(test_dict_with_list_of_dicts) == {
"a": 0,
"b.e": [{"x": {"z": [222, 333]}}, {"y": {"g": [99, 100]}}], # doesnt get flattened
"b.f": 1,
"c": 2,
}
assert util.flatten_dict(test_dict_with_list_of_dicts, include_lists=True) == {
"a": 0,
"b.e.0.x.z.0": 222,
"b.e.0.x.z.1": 333,
"b.e.1.y.g.0": 99,
"b.e.1.y.g.1": 100,
"b.f": 1,
"c": 2,
}
Loading