diff --git a/llm/download.py b/llm/download.py index f9b619e..2472804 100644 --- a/llm/download.py +++ b/llm/download.py @@ -13,6 +13,7 @@ from collections import Counter import re import uuid +from typing import List import huggingface_hub as hfh from huggingface_hub.utils import HfHubHTTPError from utils.marsgen import get_mar_name, generate_mars @@ -38,7 +39,7 @@ MODEL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "model_config.json") -def get_ignore_pattern_list(extension_list): +def get_ignore_pattern_list(extension_list: List[str]) -> List[str]: """ This function takes a list of file extensions and returns a list of patterns that can be used to filter out files with these extensions during model download. @@ -53,7 +54,7 @@ def get_ignore_pattern_list(extension_list): return ["*" + pattern for pattern in extension_list] -def compare_lists(list1, list2): +def compare_lists(list1: List[str], list2: List[str]) -> bool: """ This function checks if two lists are equal by comparing their contents, regardless of the order. @@ -68,7 +69,9 @@ def compare_lists(list1, list2): return Counter(list1) == Counter(list2) -def filter_files_by_extension(filenames, extensions_to_remove): +def filter_files_by_extension( + filenames: List[str], extensions_to_remove: List[str] +) -> List[str]: """ This function takes a list of filenames and a list of extensions to remove. It returns a new list of filenames after filtering out those with specified extensions. @@ -89,7 +92,7 @@ def filter_files_by_extension(filenames, extensions_to_remove): return filtered_filenames -def check_if_model_files_exist(gen_model: GenerateDataModel): +def check_if_model_files_exist(gen_model: GenerateDataModel) -> bool: """ This function compares the list of files in the downloaded model directory with the list of files in the HuggingFace repository. It takes into account any files to @@ -113,7 +116,7 @@ def check_if_model_files_exist(gen_model: GenerateDataModel): return compare_lists(extra_files_list, repo_files) -def create_tmp_model_store(mar_output, mar_name): +def create_tmp_model_store(mar_output: str, mar_name: str) -> str: """ This function creates a temporary directory in model store in which the MAR file will be stored temporarily. @@ -132,7 +135,7 @@ def create_tmp_model_store(mar_output, mar_name): return tmp_dir -def move_mar(gen_model: GenerateDataModel, tmp_dir): +def move_mar(gen_model: GenerateDataModel, tmp_dir: str) -> None: """ This funtion moves MAR file from the temporary directory to model store. @@ -148,7 +151,7 @@ def move_mar(gen_model: GenerateDataModel, tmp_dir): mv_file(src, dst) -def read_config_for_download(gen_model: GenerateDataModel): +def read_config_for_download(gen_model: GenerateDataModel) -> GenerateDataModel: """ This function reads repo id, version and handler name from model_config.json and sets values for the GenerateDataModel object. @@ -245,7 +248,7 @@ def read_config_for_download(gen_model: GenerateDataModel): return gen_model -def run_download(gen_model: GenerateDataModel): +def run_download(gen_model: GenerateDataModel) -> GenerateDataModel: """ This function checks if the given model path directory is empty and then downloads the given version's model files at that path. 
@@ -277,7 +280,7 @@ def run_download(gen_model: GenerateDataModel): return gen_model -def create_mar(gen_model: GenerateDataModel): +def create_mar(gen_model: GenerateDataModel) -> None: """ This function checks if the Model Archive (MAR) file for the downloaded model exists in the specified model path otherwise generates the MAR file. @@ -313,7 +316,7 @@ def create_mar(gen_model: GenerateDataModel): ) -def run_script(params): +def run_script(params: argparse.Namespace) -> bool: """ This function validates input parameters, downloads model files and creates model archive file (MAR file) for the given model. @@ -321,6 +324,8 @@ def run_script(params): Args: params (Namespace): An argparse.Namespace object containing command-line arguments. These are the necessary parameters and configurations for the script. + Returns: + bool: True for successful execution and False otherwise (used for testing) """ gen_model = GenerateDataModel(params) gen_model = read_config_for_download(gen_model) diff --git a/llm/tests/test_download.py b/llm/tests/test_download.py index 04c06dd..69f99bc 100644 --- a/llm/tests/test_download.py +++ b/llm/tests/test_download.py @@ -26,7 +26,7 @@ MODEL_TEMP_CONFIG_PATH = os.path.join(MODEL_STORE, "temp_model_config.json") -def rm_dir(path): +def rm_dir(path: str) -> None: """ This function deletes a directory. @@ -39,7 +39,7 @@ def rm_dir(path): shutil.rmtree(path) -def download_setup(): +def download_setup() -> None: """ This function deletes and creates model store and model path directories. @@ -50,7 +50,7 @@ def download_setup(): os.makedirs(MODEL_STORE) -def cleanup_folders(): +def cleanup_folders() -> None: """ This function deletes model store and model path directories. """ @@ -59,12 +59,12 @@ def cleanup_folders(): def set_generate_args( - model_name=MODEL_NAME, - repo_version="", - model_path=MODEL_PATH, - mar_output=MODEL_STORE, - handler_path="", -): + model_name: str = MODEL_NAME, + repo_version: str = "", + model_path: str = MODEL_PATH, + mar_output: str = MODEL_STORE, + handler_path: str = "", +) -> argparse.Namespace: """ This function sets the arguments to run download.py. @@ -89,7 +89,7 @@ def set_generate_args( return args -def test_default_generate_success(): +def test_default_generate_success() -> None: """ This function tests the default GPT2 model. Expected result: Success. @@ -104,7 +104,7 @@ def test_default_generate_success(): assert result is True -def test_wrong_model_store_throw_error(): +def test_wrong_model_store_throw_error() -> None: """ This function tests wrong model store path. Expected result: Failure. @@ -119,7 +119,7 @@ def test_wrong_model_store_throw_error(): assert False -def test_wrong_model_path_throw_error(): +def test_wrong_model_path_throw_error() -> None: """ This function tests wrong model files path. Expected result: Failure. @@ -134,7 +134,7 @@ def test_wrong_model_path_throw_error(): assert False -def test_non_empty_model_path_throw_error(): +def test_non_empty_model_path_throw_error() -> None: """ This function tests non empty model files path without skip download. Expected result: Failure. @@ -151,7 +151,7 @@ def test_non_empty_model_path_throw_error(): assert False -def test_invalid_repo_version_throw_error(): +def test_invalid_repo_version_throw_error() -> None: """ This function tests invalid repo version. Expected result: Failure. 
@@ -166,7 +166,7 @@ def test_invalid_repo_version_throw_error(): assert False -def test_valid_repo_version_success(): +def test_valid_repo_version_success() -> None: """ This function tests valid repo version. Expected result: Success. @@ -181,7 +181,7 @@ def test_valid_repo_version_success(): assert result is True -def test_invalid_handler_throw_error(): +def test_invalid_handler_throw_error() -> None: """ This function tests invalid handler path. Expected result: Failure. @@ -196,7 +196,7 @@ def test_invalid_handler_throw_error(): assert False -def test_skip_download_throw_error(): +def test_skip_download_throw_error() -> None: """ This function tests skip download without model files. Expected result: Failure. @@ -212,7 +212,7 @@ def test_skip_download_throw_error(): assert False -def test_mar_exists_throw_error(): +def test_mar_exists_throw_error() -> None: """ This function tests if MAR file already exists. Expected result: Exits. @@ -228,7 +228,7 @@ def test_mar_exists_throw_error(): assert False -def test_skip_download_success(): +def test_skip_download_success() -> None: """ This function tests skip download case. Expected result: Success. @@ -250,7 +250,7 @@ def test_skip_download_success(): assert result is True -def custom_model_setup(): +def custom_model_setup() -> None: """ This function is used to setup custom model case. It runs download.py to download model files and @@ -267,7 +267,7 @@ def custom_model_setup(): json.dump({}, file) -def custom_model_restore(): +def custom_model_restore() -> None: """ This function restores the 'model_config.json' file and runs cleanup_folders function. @@ -277,7 +277,7 @@ def custom_model_restore(): cleanup_folders() -def test_custom_model_success(): +def test_custom_model_success() -> None: """ This function tests the custom model case. This is done by clearing the 'model_config.json' and diff --git a/llm/tests/test_torchserve_run.py b/llm/tests/test_torchserve_run.py index 6b0a511..97470e7 100644 --- a/llm/tests/test_torchserve_run.py +++ b/llm/tests/test_torchserve_run.py @@ -6,6 +6,7 @@ """ import os import subprocess +from typing import List import pytest import download from tests.test_download import ( @@ -22,7 +23,7 @@ ) -def test_generate_mar_success(): +def test_generate_mar_success() -> None: """ This function calls the default testcase from test_download.py This is done to generate the MAR file used in the rest of the @@ -32,8 +33,11 @@ def test_generate_mar_success(): def get_run_cmd( - model_name=MODEL_NAME, model_store=MODEL_STORE, input_path="", repo_version="" -): + model_name: str = MODEL_NAME, + model_store: str = MODEL_STORE, + input_path: str = "", + repo_version: str = "", +) -> List[str]: """ This function is used to generate the bash command to be run using given parameters @@ -59,7 +63,7 @@ def get_run_cmd( return cmd.split() -def test_default_success(): +def test_default_success() -> None: """ This function tests the default GPT2 model with input path. Expected result: Success. @@ -68,7 +72,7 @@ def test_default_success(): assert process.returncode == 0 -def test_default_no_input_path_success(): +def test_default_no_input_path_success() -> None: """ This function tests the default GPT2 model without input path. Expected result: Success. @@ -77,7 +81,7 @@ def test_default_no_input_path_success(): assert process.returncode == 0 -def test_no_model_name_throw_error(): +def test_no_model_name_throw_error() -> None: """ This function tests missing model name. Expected result: Failure. 
@@ -86,7 +90,7 @@ def test_no_model_name_throw_error():
     assert process.returncode == 1
 
 
-def test_wrong_model_name_throw_error():
+def test_wrong_model_name_throw_error() -> None:
     """
     This function tests wrong model name.
     Expected result: Failure.
@@ -95,7 +99,7 @@ def test_wrong_model_name_throw_error():
     assert process.returncode == 1
 
 
-def test_no_model_store_throw_error():
+def test_no_model_store_throw_error() -> None:
     """
     This function tests missing model store.
     Expected result: Failure.
@@ -104,7 +108,7 @@ def test_no_model_store_throw_error():
     assert process.returncode == 1
 
 
-def test_wrong_model_store_throw_error():
+def test_wrong_model_store_throw_error() -> None:
     """
     This function tests wrong model store.
     Expected result: Failure.
@@ -113,7 +117,7 @@ def test_wrong_model_store_throw_error():
     assert process.returncode == 1
 
 
-def test_wrong_input_path_throw_error():
+def test_wrong_input_path_throw_error() -> None:
     """
     This function tests wrong input path.
     Expected result: Failure.
@@ -122,7 +126,7 @@ def test_wrong_input_path_throw_error():
     assert process.returncode == 1
 
 
-def test_vaild_repo_version_success():
+def test_valid_repo_version_success() -> None:
     """
     This function tests valid repo version.
     Expected result: Success.
@@ -134,7 +138,7 @@ def test_vaild_repo_version_success():
     assert process.returncode == 0
 
 
-def test_invalid_repo_version_throw_error():
+def test_invalid_repo_version_throw_error() -> None:
     """
     This function tests invalid repo version.
     Expected result: Failure.
@@ -145,7 +149,7 @@ def test_invalid_repo_version_throw_error():
     assert process.returncode == 1
 
 
-def test_custom_model_success():
+def test_custom_model_success() -> None:
     """
     This function tests custom model with input folder.
     Expected result: Success.
diff --git a/llm/torchserve_run.py b/llm/torchserve_run.py
index 741a3a1..8daab92 100644
--- a/llm/torchserve_run.py
+++ b/llm/torchserve_run.py
@@ -19,7 +19,7 @@
 MODEL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "model_config.json")
 
 
-def read_config_for_inference(params):
+def read_config_for_inference(params: argparse.Namespace) -> argparse.Namespace:
     """
     Function that reads repo version and validates GPU type
 
@@ -59,7 +59,9 @@ def read_config_for_inference(params):
     return params
 
 
-def set_mar_filepath(model_store, model_name, repo_version, is_custom_model):
+def set_mar_filepath(
+    model_store: str, model_name: str, repo_version: str, is_custom_model: bool
+) -> str:
     """
     Funtion that creates the MAR file path given the model store, model name and
     repo version. The name of the MAR file is returned from get_mar_name from marsgen.
@@ -78,7 +80,7 @@ def set_mar_filepath(model_store, model_name, repo_version, is_custom_model):
     return os.path.join(model_store, mar_name)
 
 
-def run_inference_with_mar(params):
+def run_inference_with_mar(params: argparse.Namespace) -> None:
     """
     Function that checks sets the required parameters, starts Torchserve,
     registers the model and runs inference on given input data.
@@ -91,7 +93,7 @@ def run_inference_with_mar(params): get_inference(data_model, params.debug_mode) -def run_inference(params): +def run_inference(params: argparse.Namespace) -> None: """ This function validates model store directory, MAR file path, input data directory, generates the temporary gen folder to store logs and sets model generation parameters as @@ -122,7 +124,7 @@ def run_inference(params): run_inference_with_mar(params) -def torchserve_run(params): +def torchserve_run(params: argparse.Namespace) -> None: """ This function calls cleanup function, check if model config exists and then calls run_inference. @@ -147,7 +149,7 @@ def torchserve_run(params): cleanup(params.gen_folder_name, params.stop_server, params.ts_cleanup) -def cleanup(gen_folder, ts_stop=True, ts_cleanup=True): +def cleanup(gen_folder: str, ts_stop: bool = True, ts_cleanup: bool = True) -> None: """ This function stops Torchserve, deletes the temporary gen folder and the logs in it. diff --git a/llm/utils/generate_data_model.py b/llm/utils/generate_data_model.py index 124bba7..3a951d7 100644 --- a/llm/utils/generate_data_model.py +++ b/llm/utils/generate_data_model.py @@ -3,6 +3,7 @@ and function set_values that sets the GenerateDataModel attributes. """ +import argparse import os import dataclasses import sys @@ -63,7 +64,7 @@ class GenerateDataModel: repo_info = RepoInfo() debug = bool() - def __init__(self, params): + def __init__(self, params: argparse.Namespace) -> None: """ This is the init method that calls set_values method. @@ -72,7 +73,7 @@ def __init__(self, params): """ self.set_values(params) - def set_values(self, params): + def set_values(self, params: argparse.Namespace) -> None: """ This method sets values for the GenerateDataModel object based on the command-line arguments. @@ -91,7 +92,7 @@ def set_values(self, params): self.mar_utils.model_path = params.model_path self.mar_utils.mar_output = params.mar_output - def check_if_mar_exists(self): + def check_if_mar_exists(self) -> None: """ This method checks if MAR file of a model already exists and skips generation if the MAR file already exists diff --git a/llm/utils/inference_data_model.py b/llm/utils/inference_data_model.py index 4a13914..f6ac196 100644 --- a/llm/utils/inference_data_model.py +++ b/llm/utils/inference_data_model.py @@ -2,6 +2,7 @@ This module stores the dataclasses InferenceDataModel, TorchserveStartData and function prepare_settings to set the InferenceDataModel's ts_data. """ +import argparse import os import dataclasses @@ -46,7 +47,7 @@ class InferenceDataModel: mar_filepath = str() ts_data = TorchserveStartData() - def __init__(self, params): + def __init__(self, params: argparse.Namespace) -> None: """ This is the init method that calls set_data_model method. @@ -55,7 +56,7 @@ def __init__(self, params): """ self.set_data_model(params) - def set_data_model(self, args): + def set_data_model(self, args: argparse.Namespace) -> None: """ This method sets model_name, input_path, gen_folder, mar_filepath, repo_version attributes of the InferenceDataModel class. @@ -69,7 +70,7 @@ def set_data_model(self, args): self.mar_filepath = args.mar self.repo_version = args.repo_version - def prepare_settings(self): + def prepare_settings(self) -> None: """ This method sets ts_data attribute of InferenceDataModel class, sets environment variables LOG_LOCATION and METRICS_LOCATION and makes gen folder. 
diff --git a/llm/utils/inference_utils.py b/llm/utils/inference_utils.py
index c758446..0e5030c 100644
--- a/llm/utils/inference_utils.py
+++ b/llm/utils/inference_utils.py
@@ -5,6 +5,7 @@
 import sys
 import time
 import traceback
+from typing import List, Dict
 import requests
 import utils.tsutils as ts
 import utils.system_utils as su
@@ -14,7 +15,7 @@
 )
 
 
-def error_msg_print():
+def error_msg_print() -> None:
     """
     This function prints an error message and stops Torchserve.
     """
@@ -24,7 +25,7 @@ def error_msg_print():
     ts.stop_torchserve()
 
 
-def start_ts_server(ts_data: TorchserveStartData, debug):
+def start_ts_server(ts_data: TorchserveStartData, debug: bool) -> None:
     """
     This function starts Torchserve by calling start_torchserve from tsutils
     and throws error if it doesn't start.
@@ -39,7 +40,7 @@ def start_ts_server(ts_data: TorchserveStartData, debug):
     sys.exit(1)
 
 
-def ts_health_check(model_name, model_timeout=1200):
+def ts_health_check(model_name: str, model_timeout: int = 1200) -> None:
     """
     This function checks if the model is registered or not.
     Args:
@@ -70,7 +71,7 @@ def ts_health_check(model_name, model_timeout=1200):
     sys.exit(1)
 
 
-def execute_inference_on_inputs(model_inputs, model_name):
+def execute_inference_on_inputs(model_inputs: List[str], model_name: str) -> None:
     """
     This function runs inference on given input data files and model name
     by calling run_inference from tsutils.
@@ -93,7 +94,7 @@ def execute_inference_on_inputs(model_inputs, model_name):
     sys.exit(1)
 
 
-def validate_inference_model(models_to_validate, debug):
+def validate_inference_model(models_to_validate: List[Dict], debug: bool) -> None:
     """
     This function consolidates model name and input to use for inference
     and calls execute_inference_on_inputs
@@ -115,7 +116,7 @@ def validate_inference_model(models_to_validate, debug):
     print(f"## {model_name} Handler is stable. \n")
 
 
-def get_inference(data_model: InferenceDataModel, debug):
+def get_inference(data_model: InferenceDataModel, debug: bool) -> None:
     """
     This function starts Torchserve, runs health check of server, registers model,
     and runs inference on input folder path. It catches KeyError and HTTPError exceptions
diff --git a/llm/utils/marsgen.py b/llm/utils/marsgen.py
index 549695e..d62376b 100644
--- a/llm/utils/marsgen.py
+++ b/llm/utils/marsgen.py
@@ -7,6 +7,7 @@
 import os
 import sys
 import subprocess
+from typing import Dict
 from utils.system_utils import check_if_path_exists, get_all_files_in_directory
 from utils.generate_data_model import GenerateDataModel
 
@@ -14,7 +15,9 @@
 MAR_NAME_LEN = 7
 
 
-def get_mar_name(model_name, repo_version, is_custom_model=False):
+def get_mar_name(
+    model_name: str, repo_version: str, is_custom_model: bool = False
+) -> str:
     """
     This function returns MAR file name using model name and
     repo version.
@@ -35,8 +38,11 @@
 
 
 def generate_mars(
-    gen_model: GenerateDataModel, mar_config, model_store_dir, debug=False
-):
+    gen_model: GenerateDataModel,
+    mar_config: str,
+    model_store_dir: str,
+    debug: bool = False,
+) -> None:
     """
     This function runs Torch Model Archiver command to generate MAR file.
 It calls the model_archiver_command_builder function to generate the command which it then runs
@@ -99,12 +105,15 @@ def generate_mars(
     os.chdir(cwd)
 
 
-def model_archiver_command_builder(model_archiver_args, debug=False):
+def model_archiver_command_builder(
+    model_archiver_args: Dict[str, str], debug: bool = False
+) -> str:
     """
     This function makes the Torch Model Archiver command using
     model_archiver_args parameter.
     Args:
-        model_archiver_args (dict): Contains
+        model_archiver_args (dict): Dictionary of arguments required to generate
+            the torch model archiver command
         debug (bool, optional): Flag to print debug statements. Defaults to False.
 
     Returns:
diff --git a/llm/utils/shell_utils.py b/llm/utils/shell_utils.py
index 7af5516..d8fa8bd 100644
--- a/llm/utils/shell_utils.py
+++ b/llm/utils/shell_utils.py
@@ -8,7 +8,7 @@
 from pathlib import Path
 
 
-def rm_file(path, regex=False):
+def rm_file(path: str, regex: bool = False) -> None:
     """
     This function deletes file or files in a path recursively.
 
@@ -27,7 +27,7 @@ def rm_file(path, regex=False):
     os.remove(path)
 
 
-def rm_dir(path):
+def rm_dir(path: str) -> None:
     """
     This function deletes a directory.
 
@@ -40,7 +40,7 @@ def rm_dir(path):
     shutil.rmtree(path)
 
 
-def mv_file(src, dst):
+def mv_file(src: str, dst: str) -> None:
     """
     This function moves a file from src to dst.
 
@@ -51,7 +51,7 @@ def mv_file(src, dst):
     shutil.move(src, dst)
 
 
-def copy_file(source_file, destination_file):
+def copy_file(source_file: str, destination_file: str) -> None:
     """
     This function copies a file from source file path to destination file path
     Args:
@@ -59,8 +59,6 @@ def copy_file(source_file, destination_file):
         destination_file (str): The path where the file is to be copied.
     Raises:
         Exception: If any error occurs during copying file.
-    Returns:
-        None
     """
     try:
         shutil.copy(source_file, destination_file)
diff --git a/llm/utils/system_utils.py b/llm/utils/system_utils.py
index aa73d33..e2c9270 100644
--- a/llm/utils/system_utils.py
+++ b/llm/utils/system_utils.py
@@ -15,7 +15,7 @@
 }
 
 
-def check_if_path_exists(filepath, err="", is_dir=False):
+def check_if_path_exists(filepath: str, err: str = "", is_dir: bool = False) -> None:
     """
     This function checks if a given path exists.
 
@@ -31,7 +31,7 @@ def check_if_path_exists(filepath, err="", is_dir=False):
     sys.exit(1)
 
 
-def create_folder_if_not_exists(path):
+def create_folder_if_not_exists(path: str) -> None:
     """
     This function creates a dirctory if it doesn't already exist.
 
@@ -42,7 +42,7 @@ def create_folder_if_not_exists(path):
     print(f"The new directory is created! - {path}")
 
 
-def check_if_folder_empty(path):
+def check_if_folder_empty(path: str) -> bool:
     """
     This function checks if a directory is empty.
@@ -56,7 +56,7 @@ def check_if_folder_empty(path): return len(dir_items) == 0 -def remove_suffix_if_starts_with(string, suffix): +def remove_suffix_if_starts_with(string: str, suffix: str) -> str: """ This function removes a suffix of a string is it starts with a given suffix diff --git a/llm/utils/tsutils.py b/llm/utils/tsutils.py index 002ffc7..cc67325 100644 --- a/llm/utils/tsutils.py +++ b/llm/utils/tsutils.py @@ -11,6 +11,7 @@ import platform import time import json +from typing import Tuple, Dict import requests from utils.inference_data_model import InferenceDataModel, TorchserveStartData from utils.system_utils import check_if_path_exists @@ -30,7 +31,7 @@ } -def generate_ts_start_cmd(ncs, ts_data: TorchserveStartData, debug): +def generate_ts_start_cmd(ts_data: TorchserveStartData, ncs: bool, debug: bool) -> str: """ This function generates the Torchserve start command. @@ -59,7 +60,12 @@ def generate_ts_start_cmd(ncs, ts_data: TorchserveStartData, debug): return cmd -def start_torchserve(ts_data: TorchserveStartData, ncs=True, wait_for=10, debug=False): +def start_torchserve( + ts_data: TorchserveStartData, + ncs: bool = True, + wait_for: int = 10, + debug: bool = False, +) -> bool: """ This function calls generate_ts_start_cmd function to get the Torchserve start command and runs the same to start Torchserve. @@ -74,7 +80,7 @@ def start_torchserve(ts_data: TorchserveStartData, ncs=True, wait_for=10, debug= bool: True for successful Torchserve start and False otherwise """ print("\n## Starting TorchServe \n") - cmd = generate_ts_start_cmd(ncs, ts_data, debug) + cmd = generate_ts_start_cmd(ts_data, ncs, debug) if debug: print(cmd) status = os.system(cmd) @@ -87,7 +93,7 @@ def start_torchserve(ts_data: TorchserveStartData, ncs=True, wait_for=10, debug= return False -def stop_torchserve(wait_for=10): +def stop_torchserve(wait_for: int = 10) -> bool: """ This function is used to stop Torchserve. @@ -113,7 +119,7 @@ def stop_torchserve(wait_for=10): return False -def set_config_properties(data_model: InferenceDataModel): +def set_config_properties(data_model: InferenceDataModel) -> None: """ This function creates a configuration file for the model and sets certain parameters. Args: @@ -157,7 +163,7 @@ class with relevant information. data_model.ts_data.ts_config_file = dst_config_path -def set_model_params(model_name): +def set_model_params(model_name: str) -> None: """ This function reads generation parameters from model_config.json and sets them as environment variables for the handler to read. The generation parameters are : @@ -192,7 +198,7 @@ def set_model_params(model_name): del os.environ[param_name] -def get_params_for_registration(model_name): +def get_params_for_registration(model_name: str) -> Tuple[str, str, str, str]: """ This function reads registration parameters from model_config.json returns them. The generation parameters are : @@ -202,7 +208,8 @@ def get_params_for_registration(model_name): model_name (str): Name of the model. 
Returns: - str: initial_workers, batch_size, max_batch_delay, response_timeout + Tuple[str, str, str, str]: initial_workers, batch_size, max_batch_delay, + response_timeout """ dirpath = os.path.dirname(__file__) initial_workers = batch_size = max_batch_delay = response_timeout = None @@ -229,8 +236,12 @@ def get_params_for_registration(model_name): def run_inference( - model_inference_data, protocol="http", host="localhost", port="8080", timeout=120 -): + model_inference_data: Dict, + protocol: str = "http", + host: str = "localhost", + port: str = "8080", + timeout: int = 120, +) -> requests.Response: """ This function sends request to run inference on Torchserve. @@ -255,8 +266,12 @@ def run_inference( def run_health_check( - model_name, protocol="http", host="localhost", port="8081", timeout=120 -): + model_name: str, + protocol: str = "http", + host: str = "localhost", + port: str = "8081", + timeout: int = 120, +) -> bool: """ This function runs a health check for the workers of the deployed model