Skip to content

Commit

Permalink
initial type check commit
Browse files Browse the repository at this point in the history
  • Loading branch information
saileshd1402 committed Nov 2, 2023
1 parent d54a129 commit bff1171
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 90 deletions.
25 changes: 15 additions & 10 deletions llm/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections import Counter
import re
import uuid
from typing import List
import huggingface_hub as hfh
from huggingface_hub.utils import HfHubHTTPError
from utils.marsgen import get_mar_name, generate_mars
Expand All @@ -38,7 +39,7 @@
MODEL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "model_config.json")


def get_ignore_pattern_list(extension_list):
def get_ignore_pattern_list(extension_list: List[str]) -> List[str]:
"""
This function takes a list of file extensions and returns a list of patterns that
can be used to filter out files with these extensions during model download.
Expand All @@ -53,7 +54,7 @@ def get_ignore_pattern_list(extension_list):
return ["*" + pattern for pattern in extension_list]


def compare_lists(list1, list2):
def compare_lists(list1: List[str], list2: List[str]) -> bool:
"""
This function checks if two lists are equal by comparing their contents,
regardless of the order.
Expand All @@ -68,7 +69,9 @@ def compare_lists(list1, list2):
return Counter(list1) == Counter(list2)


def filter_files_by_extension(filenames, extensions_to_remove):
def filter_files_by_extension(
filenames: List[str], extensions_to_remove: List[str]
) -> List[str]:
"""
This function takes a list of filenames and a list of extensions to remove.
It returns a new list of filenames after filtering out those with specified extensions.
Expand All @@ -89,7 +92,7 @@ def filter_files_by_extension(filenames, extensions_to_remove):
return filtered_filenames


def check_if_model_files_exist(gen_model: GenerateDataModel):
def check_if_model_files_exist(gen_model: GenerateDataModel) -> bool:
"""
This function compares the list of files in the downloaded model directory with the
list of files in the HuggingFace repository. It takes into account any files to
Expand All @@ -113,7 +116,7 @@ def check_if_model_files_exist(gen_model: GenerateDataModel):
return compare_lists(extra_files_list, repo_files)


def create_tmp_model_store(mar_output, mar_name):
def create_tmp_model_store(mar_output: str, mar_name: str) -> str:
"""
This function creates a temporary directory in model store in which
the MAR file will be stored temporarily.
Expand All @@ -132,7 +135,7 @@ def create_tmp_model_store(mar_output, mar_name):
return tmp_dir


def move_mar(gen_model: GenerateDataModel, tmp_dir):
def move_mar(gen_model: GenerateDataModel, tmp_dir: str) -> None:
"""
This funtion moves MAR file from the temporary directory to model store.
Expand All @@ -148,7 +151,7 @@ def move_mar(gen_model: GenerateDataModel, tmp_dir):
mv_file(src, dst)


def read_config_for_download(gen_model: GenerateDataModel):
def read_config_for_download(gen_model: GenerateDataModel) -> GenerateDataModel:
"""
This function reads repo id, version and handler name from
model_config.json and sets values for the GenerateDataModel object.
Expand Down Expand Up @@ -245,7 +248,7 @@ def read_config_for_download(gen_model: GenerateDataModel):
return gen_model


def run_download(gen_model: GenerateDataModel):
def run_download(gen_model: GenerateDataModel) -> GenerateDataModel:
"""
This function checks if the given model path directory is empty and then
downloads the given version's model files at that path.
Expand Down Expand Up @@ -277,7 +280,7 @@ def run_download(gen_model: GenerateDataModel):
return gen_model


def create_mar(gen_model: GenerateDataModel):
def create_mar(gen_model: GenerateDataModel) -> None:
"""
This function checks if the Model Archive (MAR) file for the downloaded
model exists in the specified model path otherwise generates the MAR file.
Expand Down Expand Up @@ -313,14 +316,16 @@ def create_mar(gen_model: GenerateDataModel):
)


def run_script(params):
def run_script(params: argparse.Namespace) -> bool:
"""
This function validates input parameters, downloads model files and
creates model archive file (MAR file) for the given model.
Args:
params (Namespace): An argparse.Namespace object containing command-line arguments.
These are the necessary parameters and configurations for the script.
Returns:
bool: True for successful execution and False otherwise (used for testing)
"""
gen_model = GenerateDataModel(params)
gen_model = read_config_for_download(gen_model)
Expand Down
44 changes: 22 additions & 22 deletions llm/tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
MODEL_TEMP_CONFIG_PATH = os.path.join(MODEL_STORE, "temp_model_config.json")


def rm_dir(path):
def rm_dir(path: str) -> None:
"""
This function deletes a directory.
Expand All @@ -39,7 +39,7 @@ def rm_dir(path):
shutil.rmtree(path)


def download_setup():
def download_setup() -> None:
"""
This function deletes and creates model store and
model path directories.
Expand All @@ -50,7 +50,7 @@ def download_setup():
os.makedirs(MODEL_STORE)


def cleanup_folders():
def cleanup_folders() -> None:
"""
This function deletes model store and model path directories.
"""
Expand All @@ -59,12 +59,12 @@ def cleanup_folders():


def set_generate_args(
model_name=MODEL_NAME,
repo_version="",
model_path=MODEL_PATH,
mar_output=MODEL_STORE,
handler_path="",
):
model_name: str = MODEL_NAME,
repo_version: str = "",
model_path: str = MODEL_PATH,
mar_output: str = MODEL_STORE,
handler_path: str = "",
) -> argparse.Namespace:
"""
This function sets the arguments to run download.py.
Expand All @@ -89,7 +89,7 @@ def set_generate_args(
return args


def test_default_generate_success():
def test_default_generate_success() -> None:
"""
This function tests the default GPT2 model.
Expected result: Success.
Expand All @@ -104,7 +104,7 @@ def test_default_generate_success():
assert result is True


def test_wrong_model_store_throw_error():
def test_wrong_model_store_throw_error() -> None:
"""
This function tests wrong model store path.
Expected result: Failure.
Expand All @@ -119,7 +119,7 @@ def test_wrong_model_store_throw_error():
assert False


def test_wrong_model_path_throw_error():
def test_wrong_model_path_throw_error() -> None:
"""
This function tests wrong model files path.
Expected result: Failure.
Expand All @@ -134,7 +134,7 @@ def test_wrong_model_path_throw_error():
assert False


def test_non_empty_model_path_throw_error():
def test_non_empty_model_path_throw_error() -> None:
"""
This function tests non empty model files path without skip download.
Expected result: Failure.
Expand All @@ -151,7 +151,7 @@ def test_non_empty_model_path_throw_error():
assert False


def test_invalid_repo_version_throw_error():
def test_invalid_repo_version_throw_error() -> None:
"""
This function tests invalid repo version.
Expected result: Failure.
Expand All @@ -166,7 +166,7 @@ def test_invalid_repo_version_throw_error():
assert False


def test_valid_repo_version_success():
def test_valid_repo_version_success() -> None:
"""
This function tests valid repo version.
Expected result: Success.
Expand All @@ -181,7 +181,7 @@ def test_valid_repo_version_success():
assert result is True


def test_invalid_handler_throw_error():
def test_invalid_handler_throw_error() -> None:
"""
This function tests invalid handler path.
Expected result: Failure.
Expand All @@ -196,7 +196,7 @@ def test_invalid_handler_throw_error():
assert False


def test_skip_download_throw_error():
def test_skip_download_throw_error() -> None:
"""
This function tests skip download without model files.
Expected result: Failure.
Expand All @@ -212,7 +212,7 @@ def test_skip_download_throw_error():
assert False


def test_mar_exists_throw_error():
def test_mar_exists_throw_error() -> None:
"""
This function tests if MAR file already exists.
Expected result: Exits.
Expand All @@ -228,7 +228,7 @@ def test_mar_exists_throw_error():
assert False


def test_skip_download_success():
def test_skip_download_success() -> None:
"""
This function tests skip download case.
Expected result: Success.
Expand All @@ -250,7 +250,7 @@ def test_skip_download_success():
assert result is True


def custom_model_setup():
def custom_model_setup() -> None:
"""
This function is used to setup custom model case.
It runs download.py to download model files and
Expand All @@ -267,7 +267,7 @@ def custom_model_setup():
json.dump({}, file)


def custom_model_restore():
def custom_model_restore() -> None:
"""
This function restores the 'model_config.json' file
and runs cleanup_folders function.
Expand All @@ -277,7 +277,7 @@ def custom_model_restore():
cleanup_folders()


def test_custom_model_success():
def test_custom_model_success() -> None:
"""
This function tests the custom model case.
This is done by clearing the 'model_config.json' and
Expand Down
30 changes: 17 additions & 13 deletions llm/tests/test_torchserve_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import os
import subprocess
from typing import List
import pytest
import download
from tests.test_download import (
Expand All @@ -22,7 +23,7 @@
)


def test_generate_mar_success():
def test_generate_mar_success() -> None:
"""
This function calls the default testcase from test_download.py
This is done to generate the MAR file used in the rest of the
Expand All @@ -32,8 +33,11 @@ def test_generate_mar_success():


def get_run_cmd(
model_name=MODEL_NAME, model_store=MODEL_STORE, input_path="", repo_version=""
):
model_name: str = MODEL_NAME,
model_store: str = MODEL_STORE,
input_path: str = "",
repo_version: str = "",
) -> List[str]:
"""
This function is used to generate the bash command to be run using given
parameters
Expand All @@ -59,7 +63,7 @@ def get_run_cmd(
return cmd.split()


def test_default_success():
def test_default_success() -> None:
"""
This function tests the default GPT2 model with input path.
Expected result: Success.
Expand All @@ -68,7 +72,7 @@ def test_default_success():
assert process.returncode == 0


def test_default_no_input_path_success():
def test_default_no_input_path_success() -> None:
"""
This function tests the default GPT2 model without input path.
Expected result: Success.
Expand All @@ -77,7 +81,7 @@ def test_default_no_input_path_success():
assert process.returncode == 0


def test_no_model_name_throw_error():
def test_no_model_name_throw_error() -> None:
"""
This function tests missing model name.
Expected result: Failure.
Expand All @@ -86,7 +90,7 @@ def test_no_model_name_throw_error():
assert process.returncode == 1


def test_wrong_model_name_throw_error():
def test_wrong_model_name_throw_error() -> None:
"""
This function tests wrong model name.
Expected result: Failure.
Expand All @@ -95,7 +99,7 @@ def test_wrong_model_name_throw_error():
assert process.returncode == 1


def test_no_model_store_throw_error():
def test_no_model_store_throw_error() -> None:
"""
This function tests missing model store.
Expected result: Failure.
Expand All @@ -104,7 +108,7 @@ def test_no_model_store_throw_error():
assert process.returncode == 1


def test_wrong_model_store_throw_error():
def test_wrong_model_store_throw_error() -> None:
"""
This function tests wrong model store.
Expected result: Failure.
Expand All @@ -113,7 +117,7 @@ def test_wrong_model_store_throw_error():
assert process.returncode == 1


def test_wrong_input_path_throw_error():
def test_wrong_input_path_throw_error() -> None:
"""
This function tests wrong input path.
Expected result: Failure.
Expand All @@ -122,7 +126,7 @@ def test_wrong_input_path_throw_error():
assert process.returncode == 1


def test_vaild_repo_version_success():
def test_vaild_repo_version_success() -> None:
"""
This function tests valid repo version.
Expected result: Success.
Expand All @@ -134,7 +138,7 @@ def test_vaild_repo_version_success():
assert process.returncode == 0


def test_invalid_repo_version_throw_error():
def test_invalid_repo_version_throw_error() -> None:
"""
This function tests invalid repo version.
Expected result: Failure.
Expand All @@ -145,7 +149,7 @@ def test_invalid_repo_version_throw_error():
assert process.returncode == 1


def test_custom_model_success():
def test_custom_model_success() -> None:
"""
This function tests custom model with input folder.
Expected result: Success.
Expand Down
Loading

0 comments on commit bff1171

Please sign in to comment.