Skip to content

Commit

Permalink
Merge pull request #41 from saileshd1402/pin-setuptools
Browse files Browse the repository at this point in the history
Pin requirements versions and update download logic
  • Loading branch information
johnugeorge authored Jun 17, 2024
2 parents 0fec63a + 8fce01a commit e2ecb4d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 87 deletions.
119 changes: 35 additions & 84 deletions llm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import argparse
import json
import sys
from collections import Counter
import re
import uuid
from typing import List
import huggingface_hub as hfh
Expand All @@ -21,99 +19,48 @@
check_if_path_exists,
check_if_folder_empty,
create_folder_if_not_exists,
get_all_files_in_directory,
)
from utils.shell_utils import mv_file, rm_dir
from utils.generate_data_model import GenerateDataModel

FILE_EXTENSIONS_TO_IGNORE = [
".safetensors",
".safetensors.index.json",
".h5",
".ot",
".tflite",
".msgpack",
".onnx",
PREFERRED_MODEL_FORMATS = [".safetensors", ".bin"] # In order of Preference
OTHER_MODEL_FORMATS = [
"*.pt",
"*.h5",
"*.gguf",
"*.msgpack",
"*.tflite",
"*.ot",
"*.onnx",
]

MODEL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "model_config.json")


def get_ignore_pattern_list(extension_list: List[str]) -> List[str]:
def get_ignore_pattern_list(gen_model: GenerateDataModel) -> List[str]:
"""
This function takes a list of file extensions and returns a list of patterns that
can be used to filter out files with these extensions during model download.
This method creates a list of file extensions to ignore from a priority list based on files
present in the Hugging Face Repo. It filters out extensions not found in the repository and
returns them as ignore patterns prefixed with '*' which is expected by Hugging Face client.
Args:
extension_list (list(str)): A list of file extensions.
gen_model (GenerateDataModel): An instance of the GenerateDataModel class
Returns:
list(str): A list of patterns with '*' prepended to each extension,
suitable for filtering files.
"""
return ["*" + pattern for pattern in extension_list]


def compare_lists(list1: List[str], list2: List[str]) -> bool:
    """
    Tell whether two lists hold the same items with the same multiplicities,
    ignoring the order in which the items appear.
    Args:
        list1 (list(str)): First list of strings.
        list2 (list(str)): Second list of strings.
    Returns:
        bool: True when both lists contain identical elements (duplicates
              are counted), False otherwise.
    """
    # Counter turns each list into an element -> occurrences mapping, so the
    # comparison is order-insensitive but still sensitive to duplicates.
    first_tally = Counter(list1)
    second_tally = Counter(list2)
    return first_tally == second_tally


def filter_files_by_extension(
    filenames: List[str], extensions_to_remove: List[str]
) -> List[str]:
    """
    This function takes a list of filenames and a list of extensions to remove.
    It returns a new list of filenames after dropping every name that ends
    with one of the given extensions. Input order is preserved.
    Args:
        filenames (list(str)): A list of filenames to be filtered.
        extensions_to_remove (list(str)): A list of file extensions (suffixes)
                                          to remove, e.g. ".safetensors".
    Returns:
        list(str): A list of filenames after filtering. When
                   extensions_to_remove is empty, all filenames are kept.
    """
    # str.endswith accepts a tuple of suffixes. Unlike the previous regex
    # built with "|".join(...), an empty tuple matches nothing, so an empty
    # extension list no longer filters out every file.
    suffixes = tuple(extensions_to_remove)
    return [filename for filename in filenames if not filename.endswith(suffixes)]


def check_if_model_files_exist(gen_model: GenerateDataModel) -> bool:
    """
    Check that the files present in the local model directory are the same
    as the files listed in the HuggingFace repository, once repository files
    with an extension from FILE_EXTENSIONS_TO_IGNORE have been left out.
    Args:
        gen_model (GenerateDataModel): An instance of the GenerateDataModel dataclass
    Returns:
        bool: True if the local files match the expected repository files,
              False otherwise.
    """
    # Files currently sitting in the model path directory.
    local_files = get_all_files_in_directory(gen_model.mar_utils.model_path)

    # Files the repository advertises for the requested revision.
    client = hfh.HfApi()
    listed_files = client.list_repo_files(
        repo_id=gen_model.repo_info.repo_id,
        revision=gen_model.repo_info.repo_version,
        token=gen_model.repo_info.hf_token,
    )

    # Ignored formats are never downloaded, so drop them before comparing.
    expected_files = filter_files_by_extension(
        listed_files, FILE_EXTENSIONS_TO_IGNORE
    )
    return compare_lists(local_files, expected_files)
repo_file_extensions = gen_model.get_repo_file_extensions()
for desired_extension in PREFERRED_MODEL_FORMATS:
if desired_extension in repo_file_extensions:
ignore_list = [
"*" + ignore_extension
for ignore_extension in PREFERRED_MODEL_FORMATS
if ignore_extension != desired_extension
]
ignore_list.extend(OTHER_MODEL_FORMATS)
return ignore_list
return []


def create_tmp_model_store(mar_output: str, mar_name: str) -> str:
Expand Down Expand Up @@ -267,14 +214,20 @@ def run_download(gen_model: GenerateDataModel) -> GenerateDataModel:
f" with version {gen_model.repo_info.repo_version}\n"
)

tmp_hf_cache = os.path.join(gen_model.mar_utils.model_path, "tmp_hf_cache")
create_folder_if_not_exists(tmp_hf_cache)

hfh.snapshot_download(
repo_id=gen_model.repo_info.repo_id,
revision=gen_model.repo_info.repo_version,
local_dir=gen_model.mar_utils.model_path,
local_dir_use_symlinks=False,
token=gen_model.repo_info.hf_token,
ignore_patterns=get_ignore_pattern_list(FILE_EXTENSIONS_TO_IGNORE),
local_dir_use_symlinks=False,
cache_dir=tmp_hf_cache,
force_download=True,
ignore_patterns=get_ignore_pattern_list(gen_model),
)
rm_dir(tmp_hf_cache)
print("## Successfully downloaded model_files\n")
return gen_model

Expand All @@ -289,10 +242,8 @@ def create_mar(gen_model: GenerateDataModel) -> None:
Args:
gen_model (GenerateDataModel): An instance of the GenerateDataModel dataclass
"""
if not (
gen_model.is_custom_model and gen_model.skip_download
) and not check_if_model_files_exist(gen_model):
print("## Model files do not match HuggingFace repository files")
if check_if_folder_empty(gen_model.mar_utils.model_path):
print("## Model files not present in Model Path directory")
sys.exit(1)

# Creates a temporary directory with the mar_name inside model_store
Expand Down
4 changes: 3 additions & 1 deletion llm/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ fastai==2.7.12
tokenizers==0.15.0
torchdata==0.6.1
transformers== 4.38.1
huggingface-hub==0.22.2
accelerate==0.22.0
nvgpu==0.10.0
torchserve==0.8.2
torch-model-archiver==0.8.1
einops==0.6.1
bitsandbytes==0.41.1
bitsandbytes==0.41.1
setuptools==69.5.1
45 changes: 43 additions & 2 deletions llm/utils/generate_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
import dataclasses
import sys
import huggingface_hub as hfh
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
from huggingface_hub.utils import (
HfHubHTTPError,
HFValidationError,
GatedRepoError,
RepositoryNotFoundError,
RevisionNotFoundError,
)


@dataclasses.dataclass
Expand Down Expand Up @@ -131,7 +137,7 @@ def validate_hf_token(self) -> None:
)
sys.exit(1)

def validate_commit_info(self) -> str:
def validate_commit_info(self) -> None:
"""
This method validates the HuggingFace repository information and
sets the latest commit ID of the model if repo_version is None.
Expand All @@ -154,3 +160,38 @@ def validate_commit_info(self) -> str:
" or HuggingFace ID is not correct\n"
)
sys.exit(1)

def get_repo_file_extensions(self) -> set:
    """
    Collect the distinct file extensions of the files listed in the
    Hugging Face repo of the model.
    Returns:
        repo_file_extension (set): The set of all file extensions in the
                                   Hugging Face repo of the model, as
                                   returned by os.path.splitext.
    Raises:
        sys.exit(1): If listing the repo files fails with one of the
                     handled errors (e.g. invalid repo_id, repo_version
                     or HuggingFace token), the function prints an error
                     and terminates the program with an exit code of 1.
    """
    try:
        listed_files = hfh.HfApi().list_repo_files(
            repo_id=self.repo_info.repo_id,
            revision=self.repo_info.repo_version,
            token=self.repo_info.hf_token,
        )
        extensions = set()
        for repo_file in listed_files:
            # splitext yields (root, ext); only the extension part is kept.
            _, extension = os.path.splitext(repo_file)
            extensions.add(extension)
        return extensions
    except (
        GatedRepoError,
        RepositoryNotFoundError,
        RevisionNotFoundError,
        HfHubHTTPError,
        HFValidationError,
        ValueError,
        KeyError,
    ):
        print(
            "## Error: Please check either repo_id, repo_version"
            " or HuggingFace ID is not correct\n"
        )
        sys.exit(1)

0 comments on commit e2ecb4d

Please sign in to comment.