From 0fec63a1c53edf429837fb224d152c4d64f5cfce Mon Sep 17 00:00:00 2001 From: Gavrish Prabhu Date: Wed, 28 Feb 2024 15:45:49 +0530 Subject: [PATCH] Refresh request map for every process (#38) * Refresh request map for every process * fix lint * Update transformer version --- llm/cleanup.py | 1 + llm/generate.py | 1 + llm/handler.py | 14 +++++++------- llm/requirements.txt | 2 +- llm/tests/test_generate.py | 1 + llm/tests/test_torchserve_run.py | 1 + llm/torchserve_run.py | 1 + llm/utils/inference_data_model.py | 1 + llm/utils/inference_utils.py | 1 + llm/utils/marsgen.py | 1 + llm/utils/shell_utils.py | 1 + llm/utils/system_utils.py | 1 + llm/utils/tsutils.py | 1 + 13 files changed, 19 insertions(+), 8 deletions(-) diff --git a/llm/cleanup.py b/llm/cleanup.py index 7f2c88d..1c0fce3 100644 --- a/llm/cleanup.py +++ b/llm/cleanup.py @@ -5,6 +5,7 @@ Attributes: dirpath (str): Stores parent directory of module """ + import os from utils.shell_utils import rm_dir import utils.tsutils as ts diff --git a/llm/generate.py b/llm/generate.py index 1658524..8932f71 100644 --- a/llm/generate.py +++ b/llm/generate.py @@ -6,6 +6,7 @@ during download and validation of model files. MAR_CONFIG_PATH (str): Path of model_config.json. """ + import os import argparse import json diff --git a/llm/handler.py b/llm/handler.py index a2ff8c5..804ad23 100644 --- a/llm/handler.py +++ b/llm/handler.py @@ -3,6 +3,7 @@ The handler provides functions to preprocess input data, make predictions using the model, and post-process the output for a particular use case. 
""" + import logging import os from abc import ABC @@ -70,16 +71,12 @@ class LLMHandler(BaseHandler, ABC): def __init__(self): super().__init__() self.initialized = False - self.request = { - "request_list": defaultdict(int), - "request_ids": defaultdict(int), - "request_type": defaultdict(int), - } self.tokenizer = None self.map_location = None self.device = None self.model = None self.device_map = None + self.request = None def initialize(self, context): """ @@ -147,6 +144,11 @@ def preprocess(self, data: str) -> torch.Tensor: Tensor: Tokenized input data """ input_list = [] + self.request = { + "request_list": defaultdict(int), + "request_ids": defaultdict(int), + "request_type": defaultdict(int), + } for idx, input_data in enumerate(data): # Pre-process for Kserve v2 format @@ -175,7 +177,6 @@ def preprocess(self, data: str) -> torch.Tensor: self.request["request_type"][idx] = "raw" input_list.append(row_input) - logger.info("Received text: %s", ", ".join(map(str, input_list))) encoded_input = self.tokenizer(input_list, padding=True, return_tensors="pt")[ "input_ids" ].to(self.device) @@ -218,7 +219,6 @@ def inference(self, data: torch.Tensor, *args, **kwargs) -> List[str]: inference = [] inference = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - logger.info("Generated text is: %s", ", ".join(map(str, inference))) return inference def postprocess(self, data: List[str]) -> List[str]: diff --git a/llm/requirements.txt b/llm/requirements.txt index ec50fcc..e4e4a1c 100644 --- a/llm/requirements.txt +++ b/llm/requirements.txt @@ -4,7 +4,7 @@ torchtext==0.15.2 fastai==2.7.12 tokenizers==0.15.0 torchdata==0.6.1 -transformers== 4.36.0 +transformers== 4.38.1 accelerate==0.22.0 nvgpu==0.10.0 torchserve==0.8.2 diff --git a/llm/tests/test_generate.py b/llm/tests/test_generate.py index e02e9c2..0d80aea 100644 --- a/llm/tests/test_generate.py +++ b/llm/tests/test_generate.py @@ -8,6 +8,7 @@ MODEL_CONFIG_PATH: Path to model_config.json file. 
MODEL_TEMP_CONFIG_PATH: Path to backup model_config.json file. """ + import os import argparse import shutil diff --git a/llm/tests/test_torchserve_run.py b/llm/tests/test_torchserve_run.py index 36cb494..8283be5 100644 --- a/llm/tests/test_torchserve_run.py +++ b/llm/tests/test_torchserve_run.py @@ -4,6 +4,7 @@ Attributes: INPUT_PATH: Path to input data folder. """ + import os import subprocess from typing import List diff --git a/llm/torchserve_run.py b/llm/torchserve_run.py index 66c7d8f..aca005a 100644 --- a/llm/torchserve_run.py +++ b/llm/torchserve_run.py @@ -5,6 +5,7 @@ Attributes: MODEL_CONFIG_PATH (str): Path to model_config.json file. """ + import os import argparse import json diff --git a/llm/utils/inference_data_model.py b/llm/utils/inference_data_model.py index f6ac196..83eb5a6 100644 --- a/llm/utils/inference_data_model.py +++ b/llm/utils/inference_data_model.py @@ -2,6 +2,7 @@ This module stores the dataclasses InferenceDataModel, TorchserveStartData and function prepare_settings to set the InferenceDataModel's ts_data. """ + import argparse import os import dataclasses diff --git a/llm/utils/inference_utils.py b/llm/utils/inference_utils.py index eb35910..c54d0ac 100644 --- a/llm/utils/inference_utils.py +++ b/llm/utils/inference_utils.py @@ -1,6 +1,7 @@ """ This module contains utilities to start and manage Torchserve server. 
""" + import os import sys import time diff --git a/llm/utils/marsgen.py b/llm/utils/marsgen.py index a859d02..e37611c 100644 --- a/llm/utils/marsgen.py +++ b/llm/utils/marsgen.py @@ -4,6 +4,7 @@ Attributes: MAR_NAME_LEN (int): Number of characters to include from repo_version in MAR name """ + import os import sys import time diff --git a/llm/utils/shell_utils.py b/llm/utils/shell_utils.py index d8fa8bd..2f1ade0 100644 --- a/llm/utils/shell_utils.py +++ b/llm/utils/shell_utils.py @@ -2,6 +2,7 @@ This module contains utilities to run shell operations namely: remove files, remove folder, move file """ + import os import shutil import glob diff --git a/llm/utils/system_utils.py b/llm/utils/system_utils.py index ab9411d..4a05447 100644 --- a/llm/utils/system_utils.py +++ b/llm/utils/system_utils.py @@ -4,6 +4,7 @@ Attributes: nvidia_smi_cmd (dict): Contains the nvidia-smi command in different operating systems. """ + import os import sys from typing import List diff --git a/llm/utils/tsutils.py b/llm/utils/tsutils.py index 928393a..1da7120 100644 --- a/llm/utils/tsutils.py +++ b/llm/utils/tsutils.py @@ -7,6 +7,7 @@ torch_model_archiver_command (dict): Contains the torch-model-archiver command in different operating systems. """ + import os import sys import platform