
Commit

Merge branch 'main' of https://github.com/allenai/dolma into mixer-validator
Masha Iureva authored and committed on Oct 11, 2024
2 parents 890de88 + 4615d34 commit 50763bd
Showing 18 changed files with 265 additions and 53 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/CI.yml
@@ -100,7 +100,7 @@ jobs:
if: steps.cache-venv.outputs.cache-hit != 'true'
uses: actions/setup-python@v4
with:
python-version: "3.8"
python-version: "3.9"
architecture: "x64"

- name: Create a new Python environment & install maturin
@@ -109,7 +109,7 @@ jobs:
python -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install maturin
pip install maturin==1.7.1
- name: Install dolma wheels
if: steps.cache-venv.outputs.cache-hit != 'true'
@@ -279,7 +279,7 @@ jobs:
if: "startsWith(github.ref, 'refs/tags/')"
needs: [build-linux, build-windows, build-macos, sdist]
steps:
- uses: actions/download-artifact@v4.1.7
- uses: actions/download-artifact@v3
with:
name: wheels
- name: Publish to PyPI
@@ -288,4 +288,5 @@
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
with:
command: upload
maturin-version: 1.7.1
args: --skip-existing *
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "dolma"
version = "1.0.9"
version = "1.1.0"
edition = "2021"
license = "Apache-2.0"

26 changes: 1 addition & 25 deletions Makefile
@@ -1,29 +1,5 @@
UNAME := $(shell uname)

ifeq ($(UNAME), Darwin)
OS_MESSAGE := "MacOS detected"
CMAKE_SETUP := "which cmake || brew install cmake"
PROTOBUF_SETUP := "which protoc || brew install protobuf"
OPENSSL_SETUP := "which openssl || brew install openssl"
else ifeq ($(UNAME), Linux)
OS_MESSAGE := "Linux detected"
CMAKE_SETUP := "which cmake || sudo apt-get install --yes build-essential cmake"
PROTOBUF_SETUP := "which protoc || sudo apt-get install --yes protobuf-compiler"
OPENSSL_SETUP := "which openssl || sudo apt-get install --yes libssl-dev"
else
OS_MESSAGE := "Unsupported OS; please install rust, cmake, protobuf, and openssl manually"
CMAKE_SETUP := ""
PROTOBUF_SETUP := ""
OPENSSL_SETUP := ""
endif

setup:
@echo "${OS_MESSAGE}: installing..."
$(shell "${CMAKE_SETUP}")
$(shell "${PROTOBUF_SETUP}")
$(shell "${OPENSSL_SETUP}")
which cargo || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
which maturin || pip install maturin[patchelf]
@./setup.sh

publish:
maturin publish
2 changes: 1 addition & 1 deletion docs/getting-started.md
@@ -130,7 +130,7 @@ Further, we override the number of processes to use to 96 using the `--processes
# the documents to mix; note how we use a glob pattern to match all documents
"documents": [
"wikipedia/v0/documents/*.gz",
]
],
# this is the directory where the output will be written
# note how the toolkit will try to create files of size ~1GB
"output": {
2 changes: 1 addition & 1 deletion docs/mixer.md
@@ -25,7 +25,7 @@ The following parameters are supported either via CLI (e.g. `dolma mix --paramet
|`streams[].span_replacement`|No| A list of objects specifying spans of text to be replaced. |
|`streams[].span_replacement[].span`|No| A json-path expression for an attribute that contains an array of spans. Each span should be list of length three: `[start, end, score]`. |
|`streams[].span_replacement[].min_score`|No| If the span score is less than this value, the span will not be replaced. |
|`streams[].span_replacement[].replacement`|No| The text that should be inserted in place of the span. Use `{}` to represent the original text. |
|`streams[].span_replacement[].replacement`|No| The text that should be inserted in place of the span. Use `{}` to represent the original text. Field selection from the document is also supported by prefixing a jq selector with `$`. Note: escape a leading `$` if you do not wish to use the jq selector pattern. |
|`work_dir.input`|No| Path to a local scratch directory where temporary input files can be placed. If not provided, Dolma will make one for you and delete it upon completion. |
|`work_dir.output`|No| Path to a local scratch directory where temporary output files can be placed. If not provided, Dolma will make one for you and delete it upon completion. |
|`processes`|No| Number of processes to use for mixing. By default 1 process is used. |
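
For concreteness, a hypothetical `span_replacement` configuration using the new `$`-prefixed selector might look like the sketch below; the attribute names and the `$.metadata.title` selector are illustrative, not taken from a real Dolma config:

```python
# Hypothetical span_replacement entries for a mixer stream.
span_replacement = [
    {
        "span": "$.attributes.pii__email_spans",   # JSONPath to [start, end, score] spans
        "min_score": 0.5,                          # spans scoring below this are kept as-is
        "replacement": "[EMAIL REMOVED]",          # literal text; "{}" re-inserts the original span
    },
    {
        "span": "$.attributes.heading_spans",
        "min_score": 0.5,
        "replacement": "$.metadata.title",         # leading "$": pull replacement text from a document field
    },
]
```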
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -1,10 +1,10 @@
[project]
name = "dolma"
version = "1.0.12"
version = "1.1.0"
description = "Data filters"
license = { text = "Apache-2.0" }
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
dependencies = [
"anyascii>=0.3.2",
"blingfire==0.1.8",
@@ -196,12 +196,12 @@ exclude = '''
| tests/work
)
'''
target-version = ["py38", "py39", "py310", "py311", "py312"]
target-version = ["py39", "py310", "py311", "py312"]


[tool.isort]
profile = "black"
py_version = 38
py_version = 39
known_first_party = ["dolma"]
known_local_folder = ["tests", "python"]
extend_skip_glob = [
@@ -222,7 +222,7 @@ recursive = true
aggressive = 3

[tool.mypy]
python_version = "3.8"
python_version = "3.9"
ignore_missing_imports = true
no_site_packages = true
allow_redefinition = false
24 changes: 22 additions & 2 deletions python/dolma/cli/deduper.py
@@ -1,3 +1,5 @@
import fnmatch
import os
from contextlib import ExitStack
from dataclasses import dataclass
from pathlib import Path
@@ -99,6 +101,13 @@ class DedupeConfig:
partition_index: Optional[int] = field(
default=0, help="The index of the partition being processed, in the range [0, num_partitions)."
)
file_partition: Optional[bool] = field(
default=False, help="Whether or not to partition at the document level (vs at the span level)"
)
document_dir: Optional[str] = field(
default="documents",
help="The folder in source paths to replace with 'attributes' to store results, if not 'documents'",
)
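
The new `file_partition` option assigns whole input files to partitions rather than splitting work within a file. A minimal Python sketch of the assignment rule follows; the real implementation in `src/deduper.rs` uses seeded `ahash`, so the hash function below is only illustrative:

```python
import hashlib

def belongs_to_partition(path: str, num_partitions: int, partition_index: int) -> bool:
    # Hash the path to a stable integer and take it modulo the partition count.
    hashed_path = int.from_bytes(hashlib.sha256(path.encode()).digest()[:8], "big")
    return hashed_path % num_partitions == partition_index

# Worker 0 of 4 only deduplicates the files assigned to it:
paths = [f"s3://bucket/documents/part-{i:03d}.jsonl.gz" for i in range(8)]
mine = [p for p in paths if belongs_to_partition(p, num_partitions=4, partition_index=0)]
```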


@dataclass
@@ -135,7 +144,6 @@ def run(cls, parsed_config: DeduperConfig):
logger = get_logger("tagger")

dict_config: Dict[str, Any] = {}

with ExitStack() as stack:
work_dirs = stack.enter_context(make_workdirs(parsed_config.work_dir))

@@ -146,6 +154,8 @@ def run(cls, parsed_config: DeduperConfig):
"min_words": parsed_config.dedupe.min_words,
"num_partitions": parsed_config.dedupe.num_partitions,
"partition_index": parsed_config.dedupe.partition_index,
"file_partition": parsed_config.dedupe.file_partition,
"document_dir": parsed_config.dedupe.document_dir,
}
try_name = parsed_config.dedupe.name if not om.is_missing(parsed_config.dedupe, "name") else None

@@ -182,7 +192,17 @@ def run(cls, parsed_config: DeduperConfig):
# perform some path validation to make sure we don't call the mixer with invalid config
total_matching_documents = 0
for document in parsed_config.documents:
dict_config.setdefault("documents", []).append(str(document))

if not any(
fnmatch.fnmatch(dict_config["dedupe"]["document_dir"], part) for part in document.split(os.sep)
):
raise DolmaConfigError(
f"Path ({document}) does not contain expected document directory: '/{dict_config['dedupe']['document_dir']}/'. "
)

doc = str(document)

dict_config.setdefault("documents", []).append(doc)

current_matching_documents = sum(1 for _ in glob_path(document))
if current_matching_documents == 0:
2 changes: 1 addition & 1 deletion python/dolma/cli/mixer.py
@@ -40,7 +40,7 @@ class SpanReplacementConfig:
default=None,
help="Maximum score for the span to be replaced. Either min_score or max_score must be specified.",
)
replacement: str = field(default="", help="Replacement for the span")
replacement: str = field(default="", help="Replacement config for the span(s).")
syntax: str = field(
default="jsonpath",
help="Syntax to use for filter expressions. Currently only JSONPath is supported. Defaults to JSONPath.",
28 changes: 28 additions & 0 deletions python/dolma/taggers/quality.py
@@ -15,6 +15,34 @@
from ..core.registry import TaggerRegistry


@TaggerRegistry.add("dclm-oh-eli5")
class DclmQualityClassifier(BaseFastTextTagger):
MODEL_PATH = "https://huggingface.co/mlfoundations/fasttext-oh-eli5/resolve/main/openhermes_reddit_eli5_vs_rw_v2_bigram_200k_train.bin"

def __init__(self):
super().__init__(model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER)

def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]:
# Note: This slice should always be the entire document

# Clean the input text by joining all lines into a single string
text = " ".join(text_slice.doc.strip().splitlines())
pred = self.classifier.predict(text)

# Extract the predicted label and its probability
(pred_label, pred_prob) = pred
pred_label = pred_label[0]
probability_score = pred_prob[0]

# If the predicted label is 'cc', invert the probability so the score reflects the likelihood of high quality ('hq')
if pred_label == "__label__cc":
probability_score = 1 - probability_score

label = pred_label.replace("__label__", "").replace("cc", "score").replace("hq", "score")

return [Prediction(label=label, score=probability_score)]
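
A small worked sketch of the normalization above, using a fabricated prediction: both labels collapse to a single `score` label, and the `cc` probability is inverted so a higher score always means more likely high quality.

```python
# Fabricated fasttext-style prediction for illustration: 85% confidence in "__label__cc".
pred_label, pred_prob = ("__label__cc",), (0.85,)
label, score = pred_label[0], pred_prob[0]
if label == "__label__cc":
    score = 1 - score  # 0.85 "low quality" becomes a 0.15 quality score
label = label.replace("__label__", "").replace("cc", "score").replace("hq", "score")
print(label, round(score, 2))  # -> score 0.15
```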


@TaggerRegistry.add("dolma17-quality")
class Dolma17QualityClassifier(BaseFastTextTagger):
MODEL_PATH = "https://dolma-artifacts.org/fasttext_models/dolma-1_7/cc_wiki_wikiref_sw_pes2o_adult_fakenews_math_books_openhermes.bin" # noqa: E501
7 changes: 4 additions & 3 deletions python/dolma/warc/processor.py
@@ -134,9 +134,10 @@ def process_single(
extension = extension.replace(".gz", "").replace(".warc", "") + ".jsonl.gz"
destination_path = join_path(prot, *base_dst[:-1], base_dst[-1] + extension)

with smart_open.open(source_path, "rb") as warc_file, smart_open.open(
destination_path, "wb"
) as output_file:
with (
smart_open.open(source_path, "rb") as warc_file,
smart_open.open(destination_path, "wb") as output_file,
):
it = ArchiveIterator(warc_file, record_types=WarcRecordType.response | WarcRecordType.warcinfo)
for record in it:
if record.record_type == WarcRecordType.warcinfo:
40 changes: 40 additions & 0 deletions setup.sh
@@ -0,0 +1,40 @@
#!/bin/bash
set -e

UNAME="$(uname)"
PLATFORM="$(uname -m)"

if [[ $UNAME == "Darwin" ]]; then
echo "MacOS detected..."
which cmake || brew install cmake
which protoc || brew install protobuf
which openssl || brew install openssl
elif [[ $UNAME == "Linux" ]]; then
echo "Linux detected..."
which cmake || sudo apt-get install --yes build-essential cmake
which protoc || sudo apt-get install --yes protobuf-compiler
which openssl || sudo apt-get install --yes libssl-dev
else
echo "Unsupported OS; please install rust, cmake, protobuf, maturin and openssl manually!"
exit 1
fi

which cargo || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

if [[ $PLATFORM == "x86_64" ]]; then
echo "x86_64 detected..."
which maturin || pip install maturin[patchelf]
elif [[ $PLATFORM == "aarch64" ]]; then
echo "aarch64 detected..."
which maturin || pip install maturin
elif [[ $PLATFORM == arm* ]]; then
echo "arm detected..."
which maturin || pip install maturin
else
echo "Unsupported platform; please install maturin manually"
exit 0
fi
34 changes: 32 additions & 2 deletions src/deduper.rs
Expand Up @@ -14,8 +14,9 @@ use crate::s3_util;
use crate::shard::shard_config::{CompressionConfig, WorkDirConfig};
use crate::shard::{find_objects_matching_patterns, FileCache};
use crate::wimbd::tokens::tokenize;

use ahash::RandomState;
use deduper_config::*;
use std::hash::{BuildHasher, Hash, Hasher};

pub fn run(config: DeduperConfig) -> Result<u32, u32> {
let bloom_filter = BloomFilter::initialize(&config.bloom_filter).unwrap();
@@ -33,7 +34,20 @@ pub fn run(config: DeduperConfig) -> Result<u32, u32> {
let threadpool = ThreadPool::new(config.processes);
let failed_shard_count = AtomicU32::new(0);
let failed_shard_count_ref = Arc::new(failed_shard_count);
let hash_builder = RandomState::with_seeds(0, 1, 2, 3);

for p in paths {
let mut hasher = hash_builder.build_hasher();
p.hash(&mut hasher);
let hashed_path = hasher.finish();

if config.dedupe.file_partition.unwrap_or(false)
&& hashed_path % config.dedupe.num_partitions.unwrap_or(1)
!= config.dedupe.partition_index.unwrap_or(0)
{
continue;
}

let path = p.clone();
let work_dirs = config.work_dir.clone();
let dedupe = config.dedupe.clone();
@@ -121,10 +135,24 @@ fn write_attributes(
);
}

let document_key = dedupe_config
.document_dir
.unwrap_or(String::from("documents"));

let attrs_location = {
let attr_prefix = format!("/attributes/{}/", attr_key);
docs_location.replace("/documents/", &attr_prefix)
docs_location.replace(&format!("/{}/", &document_key), &attr_prefix)
};

if attrs_location == docs_location {
log::error!(
"{} does not contain {} . Not writing its attributes!",
docs_location,
&document_key
);
panic!("Attribute would be written to document location!");
}
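
The guard above exists because the attribute destination is derived from the document path by string replacement; a Python sketch of the same derivation under that assumption (paths and attribute key are illustrative):

```python
def attrs_location(docs_location: str, attr_key: str, document_dir: str = "documents") -> str:
    # Swap the configured document directory component for "attributes/<attr_key>".
    out = docs_location.replace(f"/{document_dir}/", f"/attributes/{attr_key}/")
    if out == docs_location:
        # Refuse to write if the swap changed nothing, otherwise the document would be overwritten.
        raise ValueError(f"{docs_location} does not contain /{document_dir}/")
    return out

print(attrs_location("s3://bucket/corpus/documents/shard-0000.jsonl.gz", "dedupe_paragraphs"))
# -> s3://bucket/corpus/attributes/dedupe_paragraphs/shard-0000.jsonl.gz
```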

let local_output = cache.prepare_output(&attrs_location, label_temp)?;
let mut num_processed = 0;
let mut num_observed = 0;
@@ -546,6 +574,8 @@ pub mod deduper_config {
pub skip_empty: Option<bool>,
pub num_partitions: Option<u64>,
pub partition_index: Option<u64>,
pub file_partition: Option<bool>,
pub document_dir: Option<String>,
}

#[derive(Serialize, Deserialize, Clone)]
