Skip to content

Commit

Permalink
[gguf & st] parse shard filenames in typed function (#631)
Browse files Browse the repository at this point in the history
Follow-up to #627 (review)

> why not a typed function fn(filename) that returns all the data in a
typed manner instead of the regex 🙈 .
> Currently in the next version of HF.js we can change the regex to
something else, removing or renaming the groups, and on moon's side we'd
have no clue. No compiler warning or anything.
  • Loading branch information
Mishig authored Apr 17, 2024
1 parent 08f6eaa commit c0a43bc
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 13 deletions.
11 changes: 5 additions & 6 deletions packages/gguf/src/gguf.spec.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { describe, expect, it } from "vitest";
import { GGMLQuantizationType, RE_GGUF_SHARD_FILE, gguf } from "./gguf";
import { GGMLQuantizationType, gguf, parseGgufShardFile } from "./gguf";

const URL_LLAMA = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/191239b/llama-2-7b-chat.Q2_K.gguf";
const URL_MISTRAL_7B =
Expand Down Expand Up @@ -223,11 +223,10 @@ describe("gguf", () => {

// Sharded GGUF path: expects prefix "grok-1/grok-1-q4_0", shard "00003" and total "00009" to be extracted.
it("should detect sharded gguf filename", async () => {
const ggufPath = "grok-1/grok-1-q4_0-00003-of-00009.gguf"; // https://huggingface.co/ggml-org/models/blob/fcf344adb9686474c70e74dd5e55465e9e6176ef/grok-1/grok-1-q4_0-00003-of-00009.gguf
const match = ggufPath.match(RE_GGUF_SHARD_FILE);
const ggufShardFileInfo = parseGgufShardFile(ggufPath);

expect(RE_GGUF_SHARD_FILE.test(ggufPath)).toEqual(true);
expect(match?.groups?.prefix).toEqual("grok-1/grok-1-q4_0");
expect(match?.groups?.shard).toEqual("00003");
expect(match?.groups?.total).toEqual("00009");
expect(ggufShardFileInfo?.prefix).toEqual("grok-1/grok-1-q4_0");
expect(ggufShardFileInfo?.shard).toEqual("00003");
expect(ggufShardFileInfo?.total).toEqual("00009");
});
});
18 changes: 18 additions & 0 deletions packages/gguf/src/gguf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@ export { GGUF_QUANT_DESCRIPTIONS } from "./quant-descriptions";
/** Matches any path that ends with the `.gguf` extension. */
export const RE_GGUF_FILE = /\.gguf$/;
/**
 * Matches a sharded GGUF path shaped like `<prefix>-<shard>-of-<total>.gguf`,
 * where `shard` and `total` are exactly five digits each.
 */
export const RE_GGUF_SHARD_FILE = /^(?<prefix>.*?)-(?<shard>\d{5})-of-(?<total>\d{5})\.gguf$/;

/** Components of a sharded GGUF filename, one field per named group of `RE_GGUF_SHARD_FILE`. */
export interface GgufShardFileInfo {
	/** Everything that precedes the `-<shard>-of-<total>.gguf` suffix. */
	prefix: string;
	/** Five-digit shard number, kept as the raw string (e.g. "00003"). */
	shard: string;
	/** Five-digit shard count, kept as the raw string (e.g. "00009"). */
	total: string;
}

/**
 * Breaks a sharded GGUF filename into its typed components.
 *
 * @param filename - File name or path to inspect.
 * @returns The captured `prefix`, `shard` and `total` strings, or `null` when
 * `filename` does not match `RE_GGUF_SHARD_FILE`.
 */
export function parseGgufShardFile(filename: string): GgufShardFileInfo | null {
	const captured = RE_GGUF_SHARD_FILE.exec(filename)?.groups;
	if (!captured) {
		// Either no match at all, or a match without named groups.
		return null;
	}
	return {
		prefix: captured["prefix"],
		shard: captured["shard"],
		total: captured["total"],
	};
}

/** Type guard: treats exactly the numbers 1, 2 and 3 as a `Version`; everything else is rejected. */
const isVersion = (version: number): version is Version => {
	switch (version) {
		case 1:
		case 2:
		case 3:
			return true;
		default:
			return false;
	}
};

/**
Expand Down
13 changes: 6 additions & 7 deletions packages/hub/src/lib/parse-safetensors-metadata.spec.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { assert, it, describe } from "vitest";
import { RE_SAFETENSORS_SHARD_FILE, parseSafetensorsMetadata } from "./parse-safetensors-metadata";
import { parseSafetensorsMetadata, parseSafetensorsShardFile } from "./parse-safetensors-metadata";
import { sum } from "../utils/sum";

describe("parseSafetensorsMetadata", () => {
Expand Down Expand Up @@ -112,12 +112,11 @@ describe("parseSafetensorsMetadata", () => {

// Sharded safetensors filename: expects prefix "model_", basePrefix "model", shard "00005" and total "00072".
it("should detect sharded safetensors filename", async () => {
const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
const match = safetensorsFilename.match(RE_SAFETENSORS_SHARD_FILE);
const safetensorsShardFileInfo = parseSafetensorsShardFile(safetensorsFilename);

assert.strictEqual(RE_SAFETENSORS_SHARD_FILE.test(safetensorsFilename), true);
assert.strictEqual(match?.groups?.prefix, "model_");
assert.strictEqual(match?.groups?.basePrefix, "model");
assert.strictEqual(match?.groups?.shard, "00005");
assert.strictEqual(match?.groups?.total, "00072");
assert.strictEqual(safetensorsShardFileInfo?.prefix, "model_");
assert.strictEqual(safetensorsShardFileInfo?.basePrefix, "model");
assert.strictEqual(safetensorsShardFileInfo?.shard, "00005");
assert.strictEqual(safetensorsShardFileInfo?.total, "00072");
});
});
19 changes: 19 additions & 0 deletions packages/hub/src/lib/parse-safetensors-metadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,25 @@ export const RE_SAFETENSORS_FILE = /\.safetensors$/;
/** Matches any path that ends with `.safetensors.index.json`. */
export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
/**
 * Matches a sharded safetensors filename shaped like
 * `<basePrefix><_ or -><shard>-of-<total>.safetensors`, where `shard` and
 * `total` are exactly five digits each. The `prefix` group includes the
 * trailing `_`/`-` separator; `basePrefix` excludes it.
 */
export const RE_SAFETENSORS_SHARD_FILE =
	/^(?<prefix>(?<basePrefix>.*?)[_-])(?<shard>\d{5})-of-(?<total>\d{5})\.safetensors$/;

/** Components of a sharded safetensors filename, one field per named group of `RE_SAFETENSORS_SHARD_FILE`. */
export interface SafetensorsShardFileInfo {
	/** Leading part of the name, including the trailing `_` or `-` separator (e.g. "model_"). */
	prefix: string;
	/** Same as `prefix` but without the trailing separator (e.g. "model"). */
	basePrefix: string;
	/** Five-digit shard number, kept as the raw string (e.g. "00005"). */
	shard: string;
	/** Five-digit shard count, kept as the raw string (e.g. "00072"). */
	total: string;
}

/**
 * Breaks a sharded safetensors filename into its typed components.
 *
 * @param filename - File name or path to inspect.
 * @returns The captured `prefix`, `basePrefix`, `shard` and `total` strings,
 * or `null` when `filename` does not match `RE_SAFETENSORS_SHARD_FILE`.
 */
export function parseSafetensorsShardFile(filename: string): SafetensorsShardFileInfo | null {
	const captured = RE_SAFETENSORS_SHARD_FILE.exec(filename)?.groups;
	if (!captured) {
		// Either no match at all, or a match without named groups.
		return null;
	}
	return {
		prefix: captured["prefix"],
		basePrefix: captured["basePrefix"],
		shard: captured["shard"],
		total: captured["total"],
	};
}

const PARALLEL_DOWNLOADS = 20;
const MAX_HEADER_LENGTH = 25_000_000;

Expand Down

0 comments on commit c0a43bc

Please sign in to comment.