diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md index be32f5602..79028ec8d 100644 --- a/docs/model_cookbook.md +++ b/docs/model_cookbook.md @@ -215,6 +215,8 @@ iree-compile /tmp/open_llama_3b_v2/open-llama-3b-v2-f16.mlir \ -o /tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb ``` +TODO: replace these instructions with the newer shortfin code + Run via `service_v1_cli.py` (shortfin serving, with tokenizer): * TODO: script (via service CLI?) to dump inputs/outputs to .bin/.npy files diff --git a/sharktank/sharktank/serving_poc/__init__.py b/sharktank/sharktank/serving_poc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/sharktank/sharktank/serving_poc/framework/logging.py b/sharktank/sharktank/serving_poc/framework/logging.py deleted file mode 100644 index fe5ffc069..000000000 --- a/sharktank/sharktank/serving_poc/framework/logging.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import os -import sys - - -# Whether debug assertions are disabled. -NDEBUG: bool = False - -_default_log_level = os.getenv("TURBINE_LOG_LEVEL", "DEBUG") - - -class DefaultFormatter(logging.Formatter): - def __init__(self): - super().__init__( - "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s", - "%m-%d %H:%M:%S", - ) - - -def _setup_logger(): - root_logger = logging.getLogger("sharktank.serving_poc") - root_logger.setLevel(logging.DEBUG) - default_handler = logging.StreamHandler(sys.stderr) - default_handler.flush = sys.stderr.flush - default_handler.setLevel(_default_log_level) - default_handler.setFormatter(DefaultFormatter()) - root_logger.addHandler(default_handler) - root_logger.propagate = False - return root_logger, default_handler - - -root_logger, default_handler = _setup_logger() - -logging.getLogger("asyncio").addHandler(default_handler) - - -def get_logger(name: str): - logger = logging.getLogger(name) - logger.setLevel(_default_log_level) - logger.addHandler(default_handler) - logger.propagate = False - return logger diff --git a/sharktank/sharktank/serving_poc/framework/session.py b/sharktank/sharktank/serving_poc/framework/session.py deleted file mode 100644 index 28af0fd44..000000000 --- a/sharktank/sharktank/serving_poc/framework/session.py +++ /dev/null @@ -1,610 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Runtime session constructs. - -Key concepts: - - * DeviceSession: A single HAL device and other process-level globals. Shared global - memory and corresponding synchronization handles are accessible from here. - * WorkQueue: Logical stream of execution, nested under the DeviceSession. Each - queue holds a timeline semaphore which sequences invocations. For these models, - we route workloads of vastly different characteristics to distinct queues (i.e. - prefill vs decode step). - * LoadedModule: Modules that have been loaded but have not yet been instantiated into - a context. - * HostContext: At least one HostContext is created per LoadedModule. It encapsulates - a VMContext and performs invocations on a dedicated thread. 
Typically, there will - be more that one HostContext per LoadedModule as it helps us load balance the - host side work across multiple OS threads, ensuring faster feeding of the device. -""" - -from typing import Any, Callable, Coroutine, Generic, TypeVar, Optional, Union - -import asyncio -import concurrent.futures -import math -import queue -from threading import Lock, Thread -import warnings - -import numpy as np - -from iree.runtime import ( # type: ignore[import-untyped] - create_hal_module, - create_io_parameters_module, - get_driver, - BufferUsage, - HalBufferView, - HalCommandBuffer, - HalDevice, - HalDeviceLoopBridge, - HalDriver, - HalElementType, - HalFence, - HalSemaphore, - MemoryType, - ParameterIndex, - VmFunction, - VmInstance, - VmContext, - VmModule, -) - -from .logging import get_logger, NDEBUG - -T = TypeVar("T") - -logger = get_logger("shark_turbine.serving.session") -_CONFIG_LOCK = Lock() -_GLOBAL_VM_INSTANCE: Optional[VmInstance] = None - - -def get_vm_instance() -> VmInstance: - global _GLOBAL_VM_INSTANCE - if not _GLOBAL_VM_INSTANCE: - with _CONFIG_LOCK: - if not _GLOBAL_VM_INSTANCE: - _GLOBAL_VM_INSTANCE = VmInstance() - return _GLOBAL_VM_INSTANCE - - -class DeviceSession: - """Top-level object associated with a single attached device.""" - - __slots__ = [ - "device", - "driver", - "_module_sets", - "queues", - "_queue_request_count", - "vm_instance", - ] - - def __init__( - self, - *, - uri: Optional[str] = None, - driver: Optional[Union[str, HalDriver]] = None, - device: Optional[HalDevice] = None, - vm_instance: Optional[VmInstance] = None, - queue_count: int = 1, - ): - self._queue_request_count = 0 - self.vm_instance = vm_instance or get_vm_instance() - if uri is not None: - assert ( - driver is None and device is None - ), "If 'uri' is given, 'driver' and 'device' cannot be set" - logger.info("Opening device by uri: %s", uri) - driver = uri_driver = get_driver(uri) - device = uri_driver.create_device_by_uri(uri) - assert driver is not None, "'driver' cannot be None" - self.driver = driver if not isinstance(driver, str) else get_driver(driver) - self.device = device if device else self.driver.create_default_device() - - # Dependent objects. - self._module_sets: dict[str, "ModuleSet"] = {} - self.queues = [WorkQueue(self, i) for i in range(queue_count)] - - def shutdown(self): - for ms in self._module_sets.values(): - ms.shutdown() - - def create_module_set(self, name: str, *, context_count: int = 1) -> "ModuleSet": - assert ( - name not in self._module_sets - ), f"Modules with name {name} already created" - lm = ModuleSet(self, name, context_count=context_count) - self._module_sets[name] = lm - return lm - - def module_set(self, name: str) -> "ModuleSet": - try: - return self._module_sets[name] - except KeyError: - raise KeyError( - f"ModuleSet '{name}' not found. 
Available: {self._module_sets.keys()}" - ) - - def queue(self, index: int = -1) -> "WorkQueue": - """Gets a queue either with an explicit index or in some rotating fashion.""" - if index >= 0: - return self.queues[index] - else: - self._queue_request_count += 1 - qc = self._queue_request_count - return self.queues[qc % len(self.queues)] - - -class ModuleSet: - __slots__ = [ - "contexts", - "modules", - "name", - "session", - "_context_counter", - ] - - def __init__(self, session: DeviceSession, name: str, *, context_count: int): - assert context_count > 0 - self.session = session - self.name = name - self.modules: list[VmModule] = [ - create_hal_module(session.vm_instance, session.device) - ] - self.contexts = [None] * context_count - self._context_counter = 0 - - @property - def initialized(self) -> bool: - return self.contexts[-1] is not None - - def add(self, *modules: VmModule): - for module in modules: - self.modules.append(module) - - def load_vmfb(self, vmfb_path: str): - logger.info("Loading VMFB %s", vmfb_path) - self.add(VmModule.mmap(self.session.vm_instance, vmfb_path)) - - def load_io_module(self, sources_path: str): - logger.info("Loading IO Module %s", sources_path) - index = ParameterIndex() - index.load(sources_path) - par_provider = index.create_provider(scope="model") - self.add(create_io_parameters_module(self.session.vm_instance, par_provider)) - - def initialize(self): - assert not self.initialized, "Already initialized" - count = len(self.contexts) - logger.info("Initializing %s contexts for %s", count, self.name) - for i in range(count): - self.contexts[i] = HostContext( - self.session, self.modules, name=f"HostContext-{self.name}-{i}" - ) - - def shutdown(self): - for hc in self.contexts: - if hc is not None: - hc.shutdown() - - def module(self, name: str) -> VmModule: - for m in self.modules: - if m.name == name: - return m - raise KeyError( - f"Module `{name}` not found. Available: {[m.name for m in self.modules]}" - ) - - def function(self, module_name: str, function_name: str) -> VmFunction: - m = self.module(module_name) - f = m.lookup_function(function_name) - if f is None: - raise KeyError( - f"Function '{function_name}' not found in '{module_name}'. 
" - f"Available: {m.function_names}" - ) - return f - - @property - def host_context(self) -> "HostContext": - """Gets a context, load balancing across available instances.""" - with _CONFIG_LOCK: - self._context_counter += 1 - counter = self._context_counter - contexts = self.contexts - context = contexts[counter % len(contexts)] - assert context is not None, "Module set not initialized" - return context - - -_ThunkQueueT = queue.SimpleQueue[Union[None, Callable[[], None]]] - - -class HostContext: - def __init__(self, session: DeviceSession, modules: list[VmModule], name: str): - self.session = session - self.vm_context = VmContext(session.vm_instance, modules=modules) - self.name = name - self.loop = asyncio.new_event_loop() - self.loop.set_debug(True) - - # def exc_handler(loop, context): - # print("[EXCEPTION]", loop, context) - # self.loop.set_exception_handler(exc_handler) - - self._device_bridge = HalDeviceLoopBridge(session.device, self.loop) - self._shutdown_future = self.loop.create_future() - logger.info(f"Starting asyncio loop thread %s", name) - self._loop_thread = Thread( - target=self.loop.run_until_complete, - args=[self._shutdown_future], - name=name, - daemon=False, - ) - self._loop_thread.start() - - def shutdown(self, join: bool = True): - if self._shutdown_future is None: - return - logger.info("Signalling shutdown of host context %s", self.name) - local_future = self._shutdown_future - del self._shutdown_future - - def _shutdown(): - local_future.set_result(True) - - self.loop.call_soon_threadsafe(_shutdown) - self._device_bridge.stop() - if join: - self._loop_thread.join() - self.loop.close() - - def __del__(self): - if hasattr(self, "_shutdown_future"): - warnings.warn(f"HostContext deallocated without shutdown(): {self}") - self.shutdown(join=False) - - def run_concurrent( - self, coro: Coroutine[Any, Any, T] - ) -> concurrent.futures.Future[T]: - """Runs a coroutine from another thread, returning a concurrent Future. - - This should be used for submitting initial work to the host context from - another thread or event loop. - - Note that the concurrent Future should have its result() retrieved to - ensure that any asynchronous exceptions are propagated. Otherwise, they will - be silently consumed. - """ - return asyncio.run_coroutine_threadsafe(coro, self.loop) - - def run_sync(self, coro: Coroutine[Any, Any, T]) -> T: - """Runs a coroutine on the host context loop from another thread. - - Waits on and returns the result. - This is primarily intended for testing. - """ - return asyncio.run_coroutine_threadsafe(coro, self.loop).result() - - def on_semaphore( - self, sem: HalSemaphore, payload: int, value: Any - ) -> asyncio.Future: - """Returns an awaitable for when the semaphore attains a payload timepoint. - - The resulting Future will take the given `value` once complete. - """ - return self._device_bridge.on_semaphore(sem, payload, value) - - -class WorkQueue: - """Models a queue as a progression of steps against a timeline semaphore.""" - - __slots__ = [ - "_device", - "_lock", - "_semaphore", - "_step", - "index", - ] - - def __init__(self, session: DeviceSession, index: int = 0): - self.index = index - self._device = session.device - self._lock = Lock() - self._semaphore = session.device.create_semaphore(0) - self._step = 0 - - def execute_sequential(self, command_buffer: HalCommandBuffer): - """Executes a list of command buffers at the current step, advancing to the - next. 
- """ - with self._lock: - current_step = self._step - next_step = current_step + 1 - self._step = next_step - sem = self._semaphore - self._device.queue_execute( - command_buffer, [(sem, current_step)], [(sem, next_step)] - ) - - def current_fence(self) -> HalFence: - """Gets a fence representing the current step.""" - with self._lock: - return HalFence.create_at(self._semaphore, self._step) - - def step_fences(self) -> tuple[HalFence, HalFence]: - """Gets two fences, one at the current step and one at the next.""" - with self._lock: - current_step = self._step - next_step = current_step + 1 - self._step = next_step - sem = self._semaphore - return HalFence.create_at(sem, current_step), HalFence.create_at(sem, next_step) - - def sync(self, host_context: HostContext) -> asyncio.Future: - """Awaitable that completes when all work currently queued completed.""" - with self._lock: - current_step = self._step - return host_context.on_semaphore(self._semaphore, current_step, True) - - def guard(self, value: T) -> "TimelineGuarded[T]": - """Guards an arbitrary value as a timeline guard at the current queue - position. The value will become available when the queue is sync'd.""" - return TimelineGuarded(value, self._semaphore, self._step) - - def __repr__(self): - with self._lock: - return f"WorkQueue[{self.index}](semaphore={self._semaphore}, step={self._step}" - - -class TransferBuffer: - """Transfer buffers are pairs of host/device buffers of a specific size. - - They are used for streaming to/from the device. - """ - - __slots__ = [ - "host_buffer", - "device_buffer", - "host_buffer_map", - "_pool", - ] - - def __init__(self, session: DeviceSession, buffer_size_bytes: int): - self.host_buffer = session.device.allocator.allocate_buffer( - memory_type=MemoryType.HOST_LOCAL | MemoryType.DEVICE_VISIBLE, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=buffer_size_bytes, - ) - self.device_buffer = session.device.allocator.allocate_buffer( - memory_type=MemoryType.DEVICE_LOCAL, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=buffer_size_bytes, - ) - self.host_buffer_map = self.host_buffer.map() - self._pool: Optional["TransferBufferPool"] = None - - @staticmethod - def allocate_shaped( - session: DeviceSession, shape: list[int], element_type: HalElementType - ) -> "TransferBuffer": - assert HalElementType.is_byte_aligned(element_type) - buffer_size_bytes = math.prod(shape) * HalElementType.dense_byte_count( - element_type - ) - return TransferBuffer(session, buffer_size_bytes) - - def recycle(self): - pool = self._pool - assert ( - pool is not None - ), f"Cannot recycle a TransferBuffer that was not acquired from a pool ({self})" - self._pool = None - pool.recycle(self) - - def h2d_array( - self, - cb: HalCommandBuffer, - shape: list[int], - element_type: HalElementType, - *, - fill_value: Any = None, - ) -> tuple[np.ndarray, HalBufferView]: - """Performs an h2d transfer on the given CommandBuffer of the given shape and - element type. - - Returns a host array and device buffer view. Because transfers do not start - until the command buffer is submitted, the host array should be populated - between the return from this call and submission. 
- """ - ary = self.host_buffer_map.asarray( - shape, HalElementType.map_to_dtype(element_type) - ) - if fill_value is not None: - ary.fill(fill_value) - bv = HalBufferView(self.device_buffer, shape, element_type) - cb.copy(self.host_buffer, self.device_buffer, length=bv.byte_length) - return ary, bv - - def __repr__(self): - if self._pool is None: - return f"TransferBuffer(FREE)" - else: - return f"TransferBuffer({self._pool})" - - if not NDEBUG: - - def __del__(self): - if self._pool is not None: - warnings.warn( - f"Deallocated TransferBuffer which needed to be recycled: {self}" - ) - - -class TransferBufferPool: - """Pool of transfer buffers of a fixed size.""" - - __slots__ = [ - "_allocator", - "_free_list", - "name", - ] - - def __init__( - self, - allocator: Callable[[], TransferBuffer], - *, - initial_capacity: int, - growable: bool = False, - name: str = "", - ): - self.name = name - if initial_capacity > 0: - self._free_list = [allocator() for _ in range(initial_capacity)] - self._allocator = None - if growable: - self._allocator = allocator - - @staticmethod - def shaped( - session: DeviceSession, - shape: list[int], - element_type: HalElementType, - *, - initial_capacity: int, - growable: bool = False, - name: str = "", - ) -> "TransferBufferPool": - """Allocates a pool of transfer buffers of the given shape.""" - if initial_capacity > 0: - logger.info( - "Allocating initial capacity %s of '%s' transfer buffers: %s x %r", - initial_capacity, - name, - shape, - element_type, - ) - return TransferBufferPool( - lambda: TransferBuffer.allocate_shaped(session, shape, element_type), - initial_capacity=initial_capacity, - growable=growable, - name=name, - ) - - @staticmethod - def sized( - session: DeviceSession, - buffer_byte_size: int, - *, - initial_capacity: int, - growable: bool = False, - name: str = "", - ) -> "TransferBufferPool": - """Allocates a pool of transfer buffers of a given size in bytes.""" - if initial_capacity > 0: - logger.info( - "Allocating initial capacity %s of '%s' transfer buffers: %s bytes", - initial_capacity, - name, - buffer_byte_size, - ) - return TransferBufferPool( - lambda: TransferBuffer(session, buffer_byte_size), - initial_capacity=initial_capacity, - growable=growable, - name=name, - ) - - def acquire(self) -> TransferBuffer: - """Acquires a transfer buffer from the pool. - - Must be returned via recycle() when done. 
- """ - free_list = self._free_list - if len(free_list) > 0: - tb = free_list.pop() - assert tb._pool is None - tb._pool = self - return tb - - allocator = self._allocator - if not allocator: - raise RuntimeError( - f"Transfer buffer pool '%s' exhausted and not growable", self.name - ) - logger.info("Grow transfer buffer pool '%s'", self.name) - tb = allocator() - assert tb._pool is None - tb._pool = self - return tb - - def recycle(self, tb: TransferBuffer): - """Recycles an acquired transfer buffer.""" - self._free_list.append(tb) - - def __repr__(self): - return f"TransferBufferPool({self.name})" - - -class AsyncResources: - """Resources held for some asynchronous scope.""" - - __slots__ = [ - "_resources", - ] - - def __init__(self): - self._resources: list[Union[TransferBuffer, "AsyncResources"]] = [] - - def acquire_transfer_buffer(self, pool: TransferBufferPool) -> TransferBuffer: - tb = pool.acquire() - self._resources.append(tb) - return tb - - def recycle(self): - for r in self._resources: - r.recycle() - self._resources.clear() - - if not NDEBUG: - - def __del__(self): - if len(self._resources) != 0: - warnings.warn( - f"Deallocated AsyncResources that was not recycled: {self}" - ) - self.recycle() - - -class TimelineGuarded(Generic[T]): - """Some form of results that are structurally available now but will not be - populated until some point in the future. - - This is used to encapsulate entities that are guarded by availability of - a timepoint. Note that we only allow a single timepoint guard in order to - simplify subsequent coordination. This will typically be the case when the - guard is derived from a queue of some form (as opposed to a gather). - """ - - __slots__ = [ - "value", - "sem", - "timeline", - ] - - def __init__(self, value: T, sem: HalSemaphore, timeline: int): - self.value = value - self.sem = sem - self.timeline = timeline - - def resolve(self, host_context: HostContext) -> asyncio.Future[T]: - """Produces an awaitable that resolves to the value once available.""" - return host_context.on_semaphore(self.sem, self.timeline, self.value) - - def __repr__(self): - return f"TimelineGuarded[{self.sem} @ {self.timeline}] = {self.value}" diff --git a/sharktank/sharktank/serving_poc/llm/__init__.py b/sharktank/sharktank/serving_poc/llm/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/sharktank/sharktank/serving_poc/llm/api/rest_server.py b/sharktank/sharktank/serving_poc/llm/api/rest_server.py deleted file mode 100644 index 67536173f..000000000 --- a/sharktank/sharktank/serving_poc/llm/api/rest_server.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Heavily adapted from the vllm api_server.py. 
- -from typing import AsyncGenerator, Optional, Sequence - -import argparse -import json - -from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse, Response, StreamingResponse -import sys -import uuid -import uvicorn - -from ...framework.logging import get_logger -from ...framework.session import DeviceSession - - -from ..service import ( - create_mock_generate_service, - GenerateService, - GenerateRequest, -) - -logger = get_logger("sharktank.serving_poc.llm.api_server") -app = FastAPI() -service: Optional[GenerateService] = None - - -def get_service() -> GenerateService: - assert service is not None, "Service was not initialized" - return service - - -@app.get("/health") -async def health() -> Response: - get_service() - return Response(status_code=200) - - -@app.post("/generate") -async def generate(request: Request) -> Response: - service = get_service() - r = await request.json() - prompt = r.pop("prompt") - stream = bool(r.pop("stream", False)) - request_id = uuid.uuid4().hex - - generate_request = GenerateRequest(request_id=request_id, prompt=prompt) - result_parts = service.handle_request(generate_request) - - if stream: - # TODO: This isn't entirely matching how others do it: we should be returning - # the full result on each update. - async def stream_contents() -> AsyncGenerator[bytes, None]: - async for part in result_parts: - response_record = json.dumps({"text": part.text}) - yield (response_record + "\0").encode() - - return StreamingResponse(stream_contents()) - - # Non-streaming just reads to the final. - async for result_part in result_parts: - if await request.is_disconnected(): - # Abort. - await service.abort(generate_request.request_id) - return Response(status_code=499) - - assert result_part is not None, "No results generated!" - return JSONResponse({"text": result_part.text}) - - -def main(clargs: Sequence[str]): - global service - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--root-path", - type=str, - default=None, - help="Root path to use for installing behind path based proxy.", - ) - parser.add_argument( - "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" - ) - parser.add_argument( - "--testing-mock-service", - action="store_true", - help="Enable the mock testing service", - ) - parser.add_argument( - "--device-uri", type=str, default="local-task", help="Device URI to serve on" - ) - - args = parser.parse_args(clargs) - - # Spin up the device machinery. - # Note that in the future, for multi-device, we will need more scaffolding for - # configuration and bringup, obviously. - device_session = DeviceSession(uri=args.device_uri) - - if args.testing_mock_service: - logger.info("Enabling mock LLM generate service") - service = create_mock_generate_service() - - app.root_path = args.root_path - uvicorn.run( - app, - host=args.host, - port=args.port, - log_level="debug", - timeout_keep_alive=args.timeout_keep_alive, - ) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/sharktank/sharktank/serving_poc/llm/attn_block_cache.py b/sharktank/sharktank/serving_poc/llm/attn_block_cache.py deleted file mode 100644 index a2299c67e..000000000 --- a/sharktank/sharktank/serving_poc/llm/attn_block_cache.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. 
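# Example client for the REST server in rest_server.py above, assuming it was
# started locally on its default port 8000 (for instance with
# --testing-mock-service); host, port, and prompt text are assumptions.
import json
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "hello", "stream": False},
)
print(resp.json()["text"])

# Streaming responses arrive as NUL-terminated JSON records.
with requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "hello", "stream": True},
    stream=True,
) as stream_resp:
    for record in stream_resp.iter_lines(delimiter=b"\0"):
        if record:
            print(json.loads(record)["text"])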
-# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Manages the block cache.""" - -from iree.runtime import ( # type: ignore - HalBufferView, - HalElementType, - BufferUsage, - MemoryType, - PyModuleInterface, - VmModule, -) - -from ..framework.logging import get_logger -from ..framework.session import DeviceSession - -from .config import human_size, CacheParams - - -logger = get_logger("sharktank.serving_poc.llm.cache") - - -class AttnBlockCacheEntry: - __slots__ = [ - "index", - "in_use", - ] - - def __init__(self, index: int): - self.index = index - self.in_use = False - - def __repr__(self): - return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})" - - -class AttnBlockCache: - def __init__(self, session: DeviceSession, cache_params: CacheParams): - self.session = session - self.cache_params = cache_params - self._initialize_block_cache() - - def _initialize_block_cache(self): - model_params = self.cache_params.model - # Allocate the on-device cache slab. - attn_block_count = self.cache_params.device_block_count - attn_block_size_elements = self.cache_params.attn_block_size_elements - attn_block_size_bytes = attn_block_size_elements * model_params.attn_dtype_size - attn_cache_size_bytes = attn_block_count * attn_block_size_bytes - - logger.info("Setting up cache for\n %r", self.cache_params) - logger.info( - "Allocating attention static cache on device of %s " - "(blocks=%s, block_size=%s bytes)", - human_size(attn_cache_size_bytes), - attn_block_count, - attn_block_size_bytes, - ) - self.attn_block_buffer = self.session.device.allocator.allocate_buffer( - memory_type=MemoryType.DEVICE_LOCAL, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=attn_cache_size_bytes, - ) - - # Attn block logical view. - self.attn_block_buffer_view = HalBufferView( - self.attn_block_buffer, - [ - attn_block_count, - attn_block_size_elements, - ], - model_params.attn_dtype, - ) - - # Accounting structs. - self.attn_block_entries = [ - AttnBlockCacheEntry(i) for i in range(attn_block_count) - ] - self.attn_block_free = list(self.attn_block_entries) - - async def acquire_attn_blocks( - self, count: int, into_list: list[AttnBlockCacheEntry] - ): - """Acquires 'count' attention blocks. - - If there are insufficient free blocks, raises an exception. - """ - free_list = self.attn_block_free - assert ( - len(free_list) >= count - ), f"Cache does not contain requested {count} free attn blocks" - for i in range(count): - into_list.append(free_list.pop()) - - async def release_attn_blocks(self, blocks: list[AttnBlockCacheEntry]): - """Releases a list of attention blocks. - - If at all possible, this should be batched to include all blocks that need to - be released at a given time since this will trigger heavy-weight scheduling - that will work better with a view of the new free list as a whole. - """ - free_list = self.attn_block_free - for block in blocks: - free_list.append(block) - - -def create_attn_block_cache_module(attn_block_cache: AttnBlockCache) -> VmModule: - """Creates a VM module that exports the attention block cache. - - For in-system use, we use a dynamic module that can provide the block cache - slab. In other uses, this may be provided by a statically compiled module - that does the same. - - Interface: - Module name: attn_block_cache - Exports: - func @attn_block_cache.get_shared_slab() -> (!hal.buffer_view) - """ - - class Module: - def __init__(self, iface): - ... 
- - def get_shared_slab(self): - return attn_block_cache.attn_block_buffer_view.ref - - iface = PyModuleInterface(module_name="attn_block_cache", ctor=Module) - iface.export("get_shared_slab", "0v_r", Module.get_shared_slab) - return iface.create() diff --git a/sharktank/sharktank/serving_poc/llm/config.py b/sharktank/sharktank/serving_poc/llm/config.py deleted file mode 100644 index df5db5f8f..000000000 --- a/sharktank/sharktank/serving_poc/llm/config.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Configuration objects. - -Parameters that are intrinsic to a specific model. - -In a typical transformer model, the KV cache is organized similar to (mapped to -our parameter names below): - k = tensor.empty(transformer_block_count, batch_size, seq, - attn_head_count, attn_head_dim) - v = ... - -For context, a popular model has parameters of: - attn_dtype_size = 2 # (fp16) - max_seq_len = 2048 - transformer_block_count = 32 - attn_head_count = 32 - attn_head_dim = 128 # (dim / head_count) - -If paging, then we primary care about the organization of a single block, where -a block represents a single position in the sequence for a single item in the batch. -Therefore, it will be organized like: - block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim) - -In this scenario, we declare that one block holds the KV cache for all transformer -block layers because it reduces the accounting. As such, for the above example, -a single position in the sequence will be 524,288 bytes, assuming a 2-byte element -type. If we choose to block by block_stride=16 positions, each block will be 8MiB. -Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536 -blocks for a total number of sequence positions of 24,576. - -These are well-known numbers but are derived above to give a sense of scale. - -In order to indirect through to the block cache, we have to provide the index map -to specific invocations: - -* Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will - need write indices of [batch_size, prompt_len // block_stride + 1]. -* Decode step: Decode is auto-regressive, and needs to first compute the new kv - row and then attend over all rows in the cache up to this point in the sequence. - -If wanting to avoid dynamic allocation of transients, we can also pool the index -tables based on the maximum batch size and maximum sequence length. Since all -block cache sizes are well within the range of an i16, we will use that for storage. -Therefore, each batch invocation would need a block lookup table of: - - byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t) - -For a max_batch_size of 16, this is 4KiB of block index table lookups per -invocation. We don't have to statically allocate this, but the system is more -predictable if we just reserve what we need. Again, numbers are given to give a -sense of scale only: real workloads will vary. -""" - -from dataclasses import dataclass - -from iree.runtime import ( # type: ignore - HalElementType, -) - -import json - - -@dataclass -class ModelParams: - """Parameters for a specific compiled model, sufficient to do cache planning and - invocations.""" - - # The element type of the attention caches. 
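# Worked check of the sizing arithmetic in the config.py docstring above, using
# the example figures it quotes (fp16 cache, 32 transformer blocks, 32 heads,
# head dim 128, block stride 16, 12GiB cache budget, max batch size 16).
attn_dtype_size = 2
transformer_block_count = 32
attn_head_count = 32
attn_head_dim = 128
block_stride = 16

# Bytes for one sequence position: K and V for every transformer block.
position_bytes = transformer_block_count * 2 * attn_head_count * attn_head_dim * attn_dtype_size
assert position_bytes == 524_288

# One cache block covers block_stride positions.
block_bytes = position_bytes * block_stride
assert block_bytes == 8 * 1024 * 1024            # 8MiB per block

# A 12GiB slab therefore holds 1536 blocks, i.e. 24,576 sequence positions.
assert (12 * 1024**3) // block_bytes == 1536
assert 1536 * block_stride == 24_576

# Block index table per invocation: max_batch_size * (max_seq_len // stride) * sizeof(int16).
assert 16 * (2048 // block_stride) * 2 == 4096   # 4KiB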
- attn_dtype: HalElementType - - # Maximum length of a sequence including prompt and output. - max_seq_len: int - - # Number of transformer blocks. - transformer_block_count: int - - # Number of attention heads per block. - attn_head_count: int - - # Dimensionality of each attention head - attn_head_dim: int - - # Position stride per attention block - block_seq_stride: int - - # Batch sizes that the prefill stage is compiled for. These are expected to be - # functions exported from the model with suffixes of "_bs{batch_size}". Must - # be in ascending order. - prefill_batch_sizes: list[int] - - # Similarly, batch sizes that the decode stage is compiled for. - decode_batch_sizes: list[int] - - # Name of the IREE module implementing the model. - module_name: str = "module" - - # ABI of the module. - module_abi_version: int = 1 - - # Size in bytes of the KV cache dtype. - @property - def attn_dtype_size(self) -> int: - assert HalElementType.is_byte_aligned(self.attn_dtype) - return HalElementType.dense_byte_count(self.attn_dtype) - - @property - def max_prefill_batch_size(self) -> int: - return self.prefill_batch_sizes[-1] - - @property - def max_decode_batch_size(self) -> int: - return self.decode_batch_sizes[-1] - - @property - def max_batch_size(self): - return max(self.max_prefill_batch_size, self.max_decode_batch_size) - - @staticmethod - def load_json(path): - f = open(path) - j = json.load(f) - return ModelParams(attn_dtype=HalElementType.FLOAT_16, **j) - - -@dataclass -class CacheParams: - """Parameters for management of the block cache. - - This is paired with a ModelParams. - - We presently use a static block cache configuration and hand-wave either a tuning - run or pen/paper analysis to derive the parameters. - """ - - model: ModelParams - - # The size of the static block cache on the device. - device_block_count: int - - # The stride of each block in sequence positions. - block_pos_stride: int - - @property - def attn_unit_size_elements(self) -> int: - """Size in bytes of each cache line in the attention cache. - - Each cache line can store a unit position stride. - """ - size = 1 - size *= self.model.transformer_block_count - size *= 2 # K and V cache line - size *= self.model.attn_head_count - size *= self.model.attn_head_dim - return size - - @property - def attn_block_size_elements(self) -> int: - """Size in bytes of each attention block of {block_position_stride} positions.""" - return self.attn_unit_size_elements * self.block_pos_stride - - -@dataclass -class ServiceParams: - """Parameters for the serving service.""" - - cache: CacheParams - model: ModelParams - - -# From: https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size -def human_size(num, suffix="B"): - for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): - if abs(num) < 1024.0: - return f"{num:3.1f}{unit}{suffix}" - num /= 1024.0 - return f"{num:.1f}Yi{suffix}" diff --git a/sharktank/sharktank/serving_poc/llm/impl/service_v1.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1.py deleted file mode 100644 index 8ae0be637..000000000 --- a/sharktank/sharktank/serving_poc/llm/impl/service_v1.py +++ /dev/null @@ -1,495 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Implements the BatchGenerateService for V1 compiled models. 
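# Hedged example of the JSON consumed by ModelParams.load_json above: attn_dtype
# is forced to fp16 by load_json, everything else comes from the file. The
# numeric values below are illustrative (borrowed from the docstring's example
# model), not from a real exported config.
import json

example_config = {
    "max_seq_len": 2048,
    "transformer_block_count": 32,
    "attn_head_count": 32,
    "attn_head_dim": 128,
    "block_seq_stride": 16,
    "prefill_batch_sizes": [4],       # must be ascending; exported as prefill_bs4
    "decode_batch_sizes": [4],        # exported as decode_bs4
    "module_name": "module",
    "module_abi_version": 1,
}

with open("/tmp/model_config.json", "w") as f:   # hypothetical path
    json.dump(example_config, f)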
- -This is far from where we want to land but is intended for first round bootstrapping. -Perhaps the biggest issue is that it wouldn't mate well as-is with samplers. -""" - -import asyncio -from dataclasses import dataclass - -import numpy as np - -from iree.runtime import ( # type: ignore - HalBufferView, - HalCommandBuffer, - HalElementType, - HalFence, - VmFunction, - VmVariantList, -) - -from ...framework.logging import get_logger, NDEBUG -from ...framework.session import ( - AsyncResources, - DeviceSession, - TimelineGuarded, - TransferBufferPool, - WorkQueue, -) - -from ..attn_block_cache import AttnBlockCacheEntry, AttnBlockCache -from ..config import ServiceParams -from ..service import ( - BatchGenerateService, - BatchGenerateState, - GenerateRequest, -) - - -logger = get_logger("sharktank.serving_poc.llm.impl.service_v1") - -EXPECTED_CONCURRENCY = 10 - - -class GenerateServiceV1(BatchGenerateService): - def __init__( - self, *, session: DeviceSession, params: ServiceParams, cache: AttnBlockCache - ): - self.params = params - self.block_pos_stride = params.cache.block_pos_stride - self.batch_sizes = params.model.prefill_batch_sizes - # TODO: Remove distinction between prefill and decode batch sizes. - assert params.model.decode_batch_sizes == self.batch_sizes - self.session = session - self.cache = cache - module_name = params.model.module_name - logger.info("Configuring serving for module set %s", module_name) - self.module_set = session.module_set(params.model.module_name) - - # Initialize prefill entry-points (1 per batch size). - self.prefill_functions: dict[int, VmFunction] = {} - for bs in self.batch_sizes: - assert bs not in self.prefill_functions - symbol_name = f"prefill_bs{bs}" - logger.info("Looking up symbol '%s'", symbol_name) - self.prefill_functions[bs] = self.module_set.function( - module_name, symbol_name - ) - - # Initialize decode entry-points (1 per batch size). - self.decode_functions: dict[int, VmFunction] = {} - for bs in self.batch_sizes: - assert bs not in self.decode_functions - symbol_name = f"decode_bs{bs}" - logger.info("Looking up symbol '%s'", symbol_name) - self.decode_functions[bs] = self.module_set.function( - module_name, symbol_name - ) - - self._initialize_transfer_pools() - - def _initialize_transfer_pools(self): - params = self.params - max_bs = params.model.max_batch_size - max_sl = params.model.max_seq_len - initial_inflight = EXPECTED_CONCURRENCY - - # block_indices_pool: array([max_batch_size, max_attn_blocks], np.int64) - # Suitable to handle the sequence->block mapping for all steps. - self.block_indices_pool = TransferBufferPool.shaped( - self.session, - [ - max_bs, - max_sl // self.block_pos_stride, - ], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="block_cache_indices", - ) - - # Prefill tokens: array([max_batch_size, max_seq_len], np.int64) - # Tokens inputs to prefill. - self.prefill_tokens_pool = TransferBufferPool.shaped( - self.session, - [ - max_bs, - max_sl, - ], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="prefill_tokens", - ) - - # Prefill sequence lengths: array([max_batch_size], np.int64) - # Sequence lengths of input tokens. - self.prefill_seq_lens_pool = TransferBufferPool.shaped( - self.session, - [max_bs], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="prefill_seq_lens", - ) - - # Decode tokens: array([max_batch_size], np.int64) - # Tokens to perform a decode step with. 
- self.decode_tokens_pool = TransferBufferPool.shaped( - self.session, - [max_bs, 1], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="decode_tokens", - ) - - # Decode seq lengths: array([max_batch_size], np.int64) - # Decoder seq length for this step - self.decode_seq_lens_pool = TransferBufferPool.shaped( - self.session, - [max_bs], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="decode_seq_len", - ) - - # Decode start positions: array([max_batch_size], np.int64) - # Tokens to perform a decode step with. - self.decode_start_pos_pool = TransferBufferPool.shaped( - self.session, - [max_bs], - HalElementType.SINT_64, - initial_capacity=initial_inflight, - growable=True, - name="decode_start_pos", - ) - - def start(self) -> "GenerateState": - return GenerateState(self) - - def shutdown(self): - self.session.shutdown() - - -class _Sequence: - __slots__ = [ - "attn_blocks", - "attn_blocks_needed", - "current_token_ids", - "decode_token_ids", - "request", - "seq_length", - ] - - current_token_ids: list[int] - decode_token_ids: list[int] - - def __init__(self, request: GenerateRequest): - self.request = request - self.seq_length: int = 0 - self.attn_blocks: list[AttnBlockCacheEntry] = [] - self.attn_blocks_needed: int = 0 - self.decode_token_ids = [] - self.current_token_ids = [] - - def attn_blocks_available(self): - return len(self.attn_blocks) - - def resize_attention(self, new_size): - old_size = self.attn_blocks_needed - self.attn_blocks_needed = new_size - return new_size - old_size - - -class GenerateState(BatchGenerateState): - __slots__ = [ - "_bs", - "_decode_function", - "_prefill_function", - "_max_attn_blocks_length", - "_max_seq_length", - "_resources", - "_service", - "_sequences", - "_batch_queue", - ] - - def __init__(self, service: GenerateServiceV1): - super().__init__(service.module_set.host_context) - self._resources = AsyncResources() - self._service = service - self._sequences: list[_Sequence] = [] - self._batch_queue = WorkQueue(service.session) - - async def recycle(self): - """Recycles or releases all resources consumed by this instance.""" - cache = self._service.cache - self._batch_queue.sync(self.host_context) - self._resources.recycle() - all_blocks = [] - for seq in self._sequences: - all_blocks.extend(seq.attn_blocks) - seq.attn_blocks.clear() - self._sequences = [] - await cache.release_attn_blocks(all_blocks) - - async def set_sequences(self, requests: list[GenerateRequest]): - """Initiates processing of a list of sequences that make up a batch. - - This is async because it acquires resources which may not be available. - """ - service = self._service - block_pos_stride = service.block_pos_stride - - # Loop through each request and reserve initial attention blocks. - bs = 0 - sequences = self._sequences - assert not sequences, "set_sequences already called" - max_attn_blocks_length = 0 - max_seq_length = 0 - attn_blocks_required = 0 - - for req in requests: - bs += 1 - seq = _Sequence(req) - sequences.append(seq) - seq.current_token_ids = req.required_prompt_token_ids - seq_length = len(seq.current_token_ids) - seq.seq_length = seq_length - max_seq_length = max(max_seq_length, seq_length) - initial_block_count = seq_length // block_pos_stride + 1 - attn_blocks_required += initial_block_count - seq.attn_blocks_needed = initial_block_count - max_attn_blocks_length = max(max_attn_blocks_length, initial_block_count) - - # Determine the appropriate batched entrypoints. 
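# Sketch of the entry-point bucketing performed just below: the compiled batch
# sizes form an ascending list, and a live batch is routed to the smallest
# compiled size that can hold it. The sizes here are assumptions.
compiled_batch_sizes = [4, 8, 16]     # e.g. ModelParams.prefill_batch_sizes
live_bs = 5                           # sequences in this GenerateState
allowed_bs = next(b for b in compiled_batch_sizes if b >= live_bs)
assert allowed_bs == 8                # dispatches to prefill_bs8 / decode_bs8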
- assert bs > 0 - for allowed_bs in service.batch_sizes: - if allowed_bs >= bs: - self._prefill_function = service.prefill_functions[allowed_bs] - self._decode_function = service.decode_functions[allowed_bs] - break - else: - raise AssertionError(f"Unsupported batch size: {bs}") - - # Acquire the needed attention blocks in one batch so as to give the scheduler - # the most visibility into the need. - logger.debug("Acquire prefill attn blocks: %s", attn_blocks_required) - all_attn_blocks: list[AttnBlockCacheEntry] = [] - await service.cache.acquire_attn_blocks(attn_blocks_required, all_attn_blocks) - block_index = 0 - for seq in sequences: - next_block_count = seq.attn_blocks_needed - seq.attn_blocks.extend( - all_attn_blocks[block_index : block_index + seq.attn_blocks_needed] - ) - block_index += next_block_count - - # Save state. - self._bs = allowed_bs - self._max_attn_blocks_length = max_attn_blocks_length - self._max_seq_length = max_seq_length - - async def prefill(self) -> TimelineGuarded[HalBufferView]: - hc = self.host_context - service = self._service - resources = self._resources - bs = self._bs - service = self._service - block_pos_stride = service.block_pos_stride - max_attn_blocks_length = self._max_attn_blocks_length - max_seq_length = max_attn_blocks_length * block_pos_stride - sequences = self._sequences - work_queue = self._batch_queue - - # Record a command buffer for performing h2d transfers. - cb = HalCommandBuffer(hc.session.device) - - # Prepare input tokens, sequence lengths and block indices. - # We acquire a transfer buffer of each from the respective pool, populate its - # host side and enqueue. - # prefill_tokens: array([bs, max_seq_length], np.int32) - prefill_tokens_host, prefill_tokens_device = resources.acquire_transfer_buffer( - service.prefill_tokens_pool - ).h2d_array(cb, [bs, max_seq_length], HalElementType.SINT_64, fill_value=0) - - # prefill_seq_lens: array([bs], np.int32) - ( - prefill_seq_lens_host, - prefill_seq_lens_device, - ) = resources.acquire_transfer_buffer(service.prefill_seq_lens_pool).h2d_array( - cb, [bs], HalElementType.SINT_64, fill_value=0 - ) - - # attn_block_indices: array([bs, max_attn_blocks], np.in16) - ( - prefill_attn_block_indices_host, - prefill_attn_block_indices_device, - ) = resources.acquire_transfer_buffer(service.block_indices_pool).h2d_array( - cb, [bs, max_attn_blocks_length], HalElementType.SINT_64, fill_value=0 - ) - - # Populate host buffers for each sequence. - for i in range(len(sequences)): - seq = sequences[i] - attn_blocks = seq.attn_blocks - current_token_ids = seq.current_token_ids - row_seq_len = len(current_token_ids) - prefill_tokens_host[i, 0:row_seq_len] = current_token_ids - prefill_seq_lens_host[i] = row_seq_len - for j in range(len(seq.attn_blocks)): - prefill_attn_block_indices_host[i, j] = attn_blocks[j].index - - # Perform h2d transfers. - cb.end() - work_queue.execute_sequential(cb) - - # Inputs: - # token_ids - # seq_lens - # attn_block_indices - # attn_block_buffer_view (the entire slab passed as input) - # wait, signal semaphores - # tied attn_block_buffer (for input[2]) - # tied attn_block_buffer (for result[0]) - inputs = VmVariantList(3) - inputs.push_ref(prefill_tokens_device) - inputs.push_ref(prefill_seq_lens_device) - inputs.push_ref(prefill_attn_block_indices_device) - inputs.push_ref(service.cache.attn_block_buffer_view) - - # Outputs: - # attn_block_buffer_view (tied output) - # decode_tokens - outputs = VmVariantList(1) - # TODO: Async invoke. 
- hc.vm_context.invoke(self._prefill_function, inputs, outputs) - return work_queue.guard(outputs.get_as_ref(0).deref(HalBufferView)) - - async def set_decode_step(self, tokens): - """Initiates processing of a list of tokens to decode across each batch - - This is async because it acquires resources which may not be available. - """ - service = self._service - block_pos_stride = service.block_pos_stride - - sequences = self._sequences - assert sequences, "set_sequences was not called yet" - assert len(sequences) == len(tokens), "expected token for each sequence" - - max_attn_blocks_length = 0 - max_seq_length = 0 - attn_blocks_required = 0 - - for tok, seq in zip(tokens, self._sequences): - seq.decode_token_ids.append(tok) - seq.seq_length = seq.seq_length + 1 - - max_seq_length = max(max_seq_length, seq.seq_length) - block_count = seq.seq_length // block_pos_stride + 1 - - seq.attn_blocks_needed = block_count - attn_blocks_required += block_count - seq.attn_blocks_available() - max_attn_blocks_length = max(max_attn_blocks_length, block_count) - - # Acquire the needed attention blocks in one batch so as to give the scheduler - # the most visibility into the need. - logger.debug("Acquire decode attn blocks: %s", attn_blocks_required) - all_attn_blocks: list[AttnBlockCacheEntry] = [] - await service.cache.acquire_attn_blocks(attn_blocks_required, all_attn_blocks) - block_index = 0 - for seq in sequences: - next_block_count = seq.attn_blocks_needed - seq.attn_blocks_available() - seq.attn_blocks.extend( - all_attn_blocks[block_index : block_index + next_block_count] - ) - block_index += next_block_count - - # Save state. - self._max_attn_blocks_length = max_attn_blocks_length - self._max_seq_length = max_seq_length - - async def decode(self) -> TimelineGuarded[HalBufferView]: - hc = self.host_context - service = self._service - resources = self._resources - bs = self._bs - max_attn_blocks_length = self._max_attn_blocks_length - sequences = self._sequences - work_queue = self._batch_queue - - # Record a command buffer for performing h2d transfers. - cb = HalCommandBuffer(hc.session.device) - - # decode_tokens: array([bs, 1], np.int32) - (decode_tokens_host, decode_tokens_device,) = resources.acquire_transfer_buffer( - service.decode_tokens_pool - ).h2d_array(cb, [bs, 1], HalElementType.SINT_64, fill_value=0) - - # decode_seq_lens: array([bs], np.int32) - ( - decode_seq_lens_host, - decode_seq_lens_device, - ) = resources.acquire_transfer_buffer(service.decode_seq_lens_pool).h2d_array( - cb, [bs], HalElementType.SINT_64, fill_value=0 - ) - - # decode_start_pos: array([bs], np.int32) - ( - decode_start_pos_host, - decode_start_pos_device, - ) = resources.acquire_transfer_buffer(service.decode_start_pos_pool).h2d_array( - cb, [bs], HalElementType.SINT_64, fill_value=0 - ) - - # attn_block_indices: array([bs, max_attn_blocks], np.in16) - ( - decode_attn_block_indices_host, - decode_attn_block_indices_device, - ) = resources.acquire_transfer_buffer(service.block_indices_pool).h2d_array( - cb, [bs, max_attn_blocks_length], HalElementType.SINT_64, fill_value=0 - ) - - # Populate host buffers for each sequence. 
- for i in range(len(sequences)): - seq = sequences[i] - attn_blocks = seq.attn_blocks - - tok = seq.decode_token_ids[0] - seq_len = len(seq.current_token_ids) - print(f"seq.current_token_ids: {seq.current_token_ids}") - seq.current_token_ids.append(tok) - seq.decode_token_ids = seq.decode_token_ids[1:] - - decode_tokens_host[i, 0] = tok - decode_start_pos_host[i] = seq_len - decode_seq_lens_host[i] = seq_len - for j in range(len(seq.attn_blocks)): - decode_attn_block_indices_host[i, j] = attn_blocks[j].index - - # Perform h2d transfers. - cb.end() - work_queue.execute_sequential(cb) - - # Inputs: - # token_ids - # seq_lens - # start_pos - # attn_block_indices - # attn_block_buffer_view (the entire slab passed as input) - # wait, signal semaphores - # tied attn_block_buffer (for input[4]) - # tied attn_block_buffer (for result[0]) - inputs = VmVariantList(5) - inputs.push_ref(decode_tokens_device) - inputs.push_ref(decode_seq_lens_device) - inputs.push_ref(decode_start_pos_device) - inputs.push_ref(decode_attn_block_indices_device) - inputs.push_ref(service.cache.attn_block_buffer_view) - - # Outputs: - # attn_block_buffer_view (tied output) - # decode_tokens - outputs = VmVariantList(1) - # TODO: Async invoke. - hc.vm_context.invoke(self._decode_function, inputs, outputs) - return work_queue.guard(outputs.get_as_ref(0).deref(HalBufferView)) diff --git a/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py deleted file mode 100644 index 7895341c9..000000000 --- a/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import asyncio -import argparse -import numpy -import sys - -from transformers import LlamaTokenizer # type: ignore - -from iree.runtime import ( # type: ignore - HalElementType, -) - -from sharktank.serving_poc.framework.session import DeviceSession - -from sharktank.serving_poc.llm.attn_block_cache import ( - create_attn_block_cache_module, - AttnBlockCache, -) - -from sharktank.serving_poc.llm.config import ( - CacheParams, - ModelParams, - ServiceParams, -) - -from sharktank.serving_poc.llm.impl.service_v1 import GenerateServiceV1 -from sharktank.serving_poc.llm.service import GenerateRequest - - -def setup(vmfb_path, config_path, gguf_path): - from iree.runtime._binding import disable_leak_checker # type: ignore - - model_params = ModelParams.load_json(config_path) - - device_block_count = model_params.max_seq_len // model_params.block_seq_stride - cache_params = CacheParams( - model=model_params, - device_block_count=device_block_count, - block_pos_stride=model_params.block_seq_stride, - ) - - disable_leak_checker() - session = DeviceSession(uri="local-sync", queue_count=2) - attn_block_cache = AttnBlockCache(session, cache_params) - - lms = session.create_module_set(model_params.module_name, context_count=1) - lms.load_io_module(gguf_path) - lms.load_vmfb(vmfb_path) - lms.add(create_attn_block_cache_module(attn_block_cache)) - lms.initialize() - - params = ServiceParams(cache=cache_params, model=model_params) - service = GenerateServiceV1(session=session, params=params, cache=attn_block_cache) - return service - - -def map_buffer(value): - mapped = value.map() - return mapped.asarray(value.shape, HalElementType.map_to_dtype(value.element_type)) - - -async def main(argv): - parser = argparse.ArgumentParser() - parser.add_argument("--tokenizer", help="name of hugginface tokenizer to use") - parser.add_argument("--config", help="json config file with hyperparameters") - parser.add_argument("--vmfb", help="vmfb with compiler LLM kernels") - parser.add_argument("--gguf", help="gguf file containing modle coefficients") - parsed = parser.parse_args(argv) - - hf_path = parsed.tokenizer - config_path = parsed.config - vmfb_path = parsed.vmfb - gguf_path = parsed.gguf - - service = setup(vmfb_path, config_path, gguf_path) - tokenizer = LlamaTokenizer.from_pretrained(hf_path) - state = service.start() - - for line in ["one two three four five six seven eight"]: - prompt = line.strip() - if not prompt: - break - - input_ids = tokenizer.encode(prompt, return_tensors="pt")[0].tolist() - print(input_ids) - request = GenerateRequest("request_id", prompt, input_ids) - await state.set_sequences([request]) - logits = await state.prefill() - - seq_len = len(input_ids) - mapped_logits = map_buffer(logits.value) - predicted_tokens = numpy.argmax(mapped_logits[0, :seq_len], axis=-1) - predicted_token = predicted_tokens[-1] - decoded_token = tokenizer.decode(predicted_token) - print(f"Prefill predicted token: {predicted_token}, decoded: '{decoded_token}'") - - # TODO(scotttodd): sanity check tokenizer use, document inputs/outputs - # 'prefill' is for initialization with multiple steps at once - # 'decode' is for hypothesis exploration, one step at a time - await state.set_decode_step([predicted_token]) - logits = await state.decode() - mapped_logits = map_buffer(logits.value) - predicted_tokens = numpy.argmax(mapped_logits, axis=-1) - predicted_token = predicted_tokens[0] - decoded_token = tokenizer.decode(predicted_token) - print(f"Decode 
predicted token: {predicted_token}, decoded: '{decoded_token}'") - await state.recycle() - - service.shutdown() - - -if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) diff --git a/sharktank/sharktank/serving_poc/llm/service.py b/sharktank/sharktank/serving_poc/llm/service.py deleted file mode 100644 index c5d4ffb44..000000000 --- a/sharktank/sharktank/serving_poc/llm/service.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import AsyncIterator, Callable, Optional - -from abc import abstractmethod, ABC -import asyncio -from dataclasses import dataclass - -from ..framework.session import ( - HostContext, -) - -######################################################################################## -# User-level single request service -######################################################################################## - - -@dataclass -class GenerateRequest: - """Encapsulates a request to perform LLM generation. - - Requests are bootstrapped from user values and then pumped through the pipeline, - receiving additional elaboration needed to actually begin generation. - """ - - # Client set fields - request_id: str - prompt: str - - # Fields that are set as the request is processed. - prompt_token_ids: Optional[list[int]] = None - - @property - def required_prompt_token_ids(self) -> list[int]: - ids = self.prompt_token_ids - assert ids is not None - return ids - - -@dataclass -class GenerateResponsePart: - """A response part from an LLM generation request.""" - - request: GenerateRequest - index: int - token_ids: list[int] - - # Fields that can be set as the response is post-processed. - text: Optional[str] = None - finished: bool = False - - -class GenerateService(ABC): - """Asynchronous generator service which processes requests into response parts.""" - - @abstractmethod - def handle_request( - self, - request: GenerateRequest, - ) -> AsyncIterator[GenerateResponsePart]: - """Generates response parts for a request.""" - ... - - @abstractmethod - async def abort(self, request_id: str) -> None: - """Aborts a submitted request.""" - ... - - -######################################################################################## -# Batch generation service -# This service is completely asynchronous and operates on a BatchGenerateRequest as -# a state machine. It is expected to have an external actor stepping it through -# states. -######################################################################################## - - -class BatchGenerateService(ABC): - """Handles generation of a batch of requests.""" - - __slots__ = [] # type: ignore - - # def start_prefill(self, request: BatchGenerateRequest): - # ... - @abstractmethod - def start(self) -> "BatchGenerateState": - ... 
- - -class BatchGenerateState(ABC): - """In-progress batch generation state.""" - - __slots__ = [ - "host_context", - ] - - def __init__(self, host_context: HostContext): - self.host_context = host_context - - -######################################################################################## -# Utilities -######################################################################################## - - -class SyncGenerateFilter(GenerateService): - """GenerateService filter which can synchronously pre/post process.""" - - __slots__ = ["_next"] - - def __init__(self, next: GenerateService): - self._next = next - - def filter_request(self, request: GenerateRequest): - ... - - def filter_response(self, part: GenerateResponsePart): - ... - - async def handle_request( - self, - request: GenerateRequest, - ) -> AsyncIterator[GenerateResponsePart]: - self.filter_request(request) - async for part in self._next.handle_request(request): - self.filter_response(part) - yield part - - async def abort(self, request_id: str) -> None: - """Aborts a submitted request.""" - await self._next.abort(request_id) - - -######################################################################################## -# Testing and mock types -######################################################################################## - - -def create_mock_generate_service() -> GenerateService: - return DummyTokenizerService(EchoGenerateService()) - - -class DummyTokenizerService(SyncGenerateFilter): - """Tokenizer service which will map to code points. - - Useful for testing. - """ - - def filter_request(self, request: GenerateRequest): - if request.prompt_token_ids is None: - request.prompt_token_ids = [ord(c) for c in request.prompt] - - def filter_response(self, part: GenerateResponsePart): - if part.text is None: - part.text = "".join([chr(x) for x in part.token_ids]) - - -class EchoGenerateService(GenerateService): - """Dummy implementation of a generate service. - - It just echoes back the request five times after a delay. - """ - - def __init__(self, delay: float = 0.1): - self._delay = delay - - async def handle_request( - self, - request: GenerateRequest, - ) -> AsyncIterator[GenerateResponsePart]: - next = None - for i in range(5): - if next: - yield next - assert request.prompt_token_ids, "Request lacks prompt tokens" - next = GenerateResponsePart( - request, i, request.prompt_token_ids, finished=False - ) - await asyncio.sleep(self._delay) - if next: - next.finished = True - yield next - - async def abort(self, request_id: str) -> None: - pass diff --git a/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py b/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py deleted file mode 100644 index a36ebe667..000000000 --- a/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Implements a service_v1 compliant module in Python for testing. - -This uses a PyModuleInterface to define a fake VmModule that exposes 'prefill_bs{n}' -and 'decode_bs{n}' such that the call sequence and args/results can be manipulated. 
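# Minimal driver for the mock generate service defined in service.py above:
# DummyTokenizerService "tokenizes" the prompt to code points and
# EchoGenerateService echoes it back five times. Request id and prompt are arbitrary.
import asyncio

from sharktank.serving_poc.llm.service import (
    GenerateRequest,
    create_mock_generate_service,
)

async def demo():
    service = create_mock_generate_service()
    request = GenerateRequest(request_id="demo-0", prompt="hi")
    async for part in service.handle_request(request):
        print(part.index, repr(part.text), part.finished)

asyncio.run(demo())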
-""" - -import numpy as np -import textwrap -import threading - -from iree.runtime import ( # type: ignore - BufferUsage, - HalBuffer, - HalBufferView, - HalDevice, - HalElementType, - HalFence, - MemoryType, - PyModuleInterface, - VmModule, - VmRef, -) - -from ..config import ModelParams - - -def create_fake_module( - device: HalDevice, module_name: str, model_params: ModelParams -) -> VmModule: - class ServiceV1Module: - def __init__(self, iface): - ... - print("IFACE:", iface, dir(iface)) - - def prefill( - self, - bs: int, - token_ids_ref: VmRef, - seq_lens_ref: VmRef, - attn_block_indices_ref: VmRef, - attn_block_buffer_view: VmRef, - ): - result_array: np.ndarray = np.ndarray([bs, 1], dtype=np.int32) - - def run(): - print(f"FAKE_V1_MODULE: PREFILL bs={bs} : WAIT") - print(" - READY") - _format_device_buffer_view( - lambda s: print(" token_ids =", s), token_ids_ref - ) - _format_device_buffer_view( - lambda s: print(" seq_lens =", s), seq_lens_ref - ) - _format_device_buffer_view( - lambda s: print(" attn_block_indices =", s), - attn_block_indices_ref, - ) - _format_device_buffer_view( - lambda s: print(" attn_block_buffer_view =", s), - attn_block_buffer_view, - ) - - # Async populate. - device_array = result_bv.map().asarray( - result_array.shape, result_array.dtype - ) - for i in range(bs): - device_array[i, 0] = i + 1 - - threading.Thread(target=run).start() - - result_buffer = device.allocator.allocate_buffer( - memory_type=MemoryType.DEVICE_LOCAL | MemoryType.HOST_VISIBLE, - allowed_usage=BufferUsage.DEFAULT, - allocation_size=result_array.size * result_array.itemsize, - ) - result_bv = HalBufferView( - result_buffer, result_array.shape, HalElementType.INT_32 - ) - return result_bv.ref - - def decode(self, bs: int): - print(f"FAKE_V1_MODULE: DECODE bs={bs}") - - iface = PyModuleInterface(module_name=module_name, ctor=ServiceV1Module) - - # Dynamically define prefill functions. - def add_prefill_bs(bs: int): - def trampoline(self, *args): - return self.prefill(bs, *args) - - iface.export(f"prefill_bs{bs}", "0rrrr_r", trampoline) - - [add_prefill_bs(bs) for bs in model_params.prefill_batch_sizes] - - # Dynamically define decode functions. - def add_decode_bs(bs: int): - def trampoline(self, *args): - return self.decode(bs, *args) - - iface.export(f"decode_bs{bs}", "0v_v", trampoline) - - [add_decode_bs(bs) for bs in model_params.decode_batch_sizes] - - return iface.create() - - -def _format_device_buffer_view(callback, bv_ref: VmRef): - bv = bv_ref.deref(HalBufferView) # type: HalBufferView - value = bv.map().asarray(bv.shape, HalElementType.map_to_dtype(bv.element_type)) - value_indented = textwrap.indent(repr(value), " ") - callback(f"{bv!r}\n{value_indented}") diff --git a/sharktank/sharktank/serving_poc/py.typed b/sharktank/sharktank/serving_poc/py.typed deleted file mode 100644 index 5e43cc13b..000000000 --- a/sharktank/sharktank/serving_poc/py.typed +++ /dev/null @@ -1 +0,0 @@ -# Marker file for PEP 561 inline type checking. diff --git a/sharktank/tests/serving_poc/framework/device_session_test.py b/sharktank/tests/serving_poc/framework/device_session_test.py deleted file mode 100644 index 5dfdd5f46..000000000 --- a/sharktank/tests/serving_poc/framework/device_session_test.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import pytest
-
-from sharktank.serving_poc.framework.session import (
-    DeviceSession,
-)
-
-
-@pytest.fixture
-def local_device_session():
-    session = DeviceSession(uri="local-task")
-    yield session
-    session.shutdown()
-
-
-def test_start_shutdown_no_host_contexts(local_device_session: DeviceSession):
-    ms = local_device_session.create_module_set("default")
-    ms.initialize()
-
-
-def test_host_context_start_stop(local_device_session: DeviceSession):
-    ms = local_device_session.create_module_set("default")
-    ms.initialize()
-    hc = ms.host_context
-
-
-def test_host_context_scheduling(local_device_session: DeviceSession):
-    device = local_device_session.device
-    ms = local_device_session.create_module_set("default")
-    ms.initialize()
-    hc = ms.host_context
-
-    sem = device.create_semaphore(0)
-
-    async def task1():
-        print("[coro1] test_host_context_scheduling.task")
-        await hc.on_semaphore(sem, 1, True)
-        print("[coro1] await completed")
-        sem.signal(2)
-
-    async def task2():
-        print("[coro2] waiting for 2")
-        await hc.on_semaphore(sem, 2, True)
-        sem.fail("Fail from task2")
-
-    f1 = hc.run_concurrent(task1())
-    f2 = hc.run_concurrent(task2())
-    sem.signal(1)
-    print("[main] Waiting for semaphore")
-
-    # Ensure task completion. Important to consume to ensure that exceptions
-    # propagate.
-    f1.result()
-    f2.result()
-
-    print("[main] Waiting on semaphore payload 3")
-    with pytest.raises(Exception, match="Fail from task2"):
-        sem.wait(3)
diff --git a/sharktank/tests/serving_poc/llm/api_server_test.py b/sharktank/tests/serving_poc/llm/api_server_test.py
deleted file mode 100644
index c2d2cc36a..000000000
--- a/sharktank/tests/serving_poc/llm/api_server_test.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import os
-from contextlib import closing
-from pathlib import Path
-import pytest
-import requests
-import socket
-import subprocess
-import sys
-import time
-
-
-def find_free_port():
-    """This tries to find a free port to run a server on for the test.
-
-    Race conditions are possible - the port can be acquired between when this
-    runs and when the server starts.
-
-    https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
-    """
-    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
-        s.bind(("localhost", 0))
-        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        return s.getsockname()[1]
-
-
-class ServerRunner:
-    def __init__(self, args):
-        port = str(find_free_port())
-        self.url = "http://localhost:" + port
-        env = os.environ.copy()
-        env["PYTHONUNBUFFERED"] = "1"
-        self.process = subprocess.Popen(
-            [
-                sys.executable,
-                "-m",
-                "sharktank.serving_poc.llm.api.rest_server",
-                "--testing-mock-service",
-                "--port=" + port,
-            ]
-            + args,
-            env=env,
-            # TODO: Have a more robust way of forking a subprocess.
-            cwd=str(Path(__file__).resolve().parent.parent.parent),
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
-        self._wait_for_ready()
-
-    def _wait_for_ready(self):
-        start = time.time()
-        while True:
-            try:
-                if requests.get(f"{self.url}/health").status_code == 200:
-                    return
-            except Exception as e:
-                if self.process.poll() is not None:
-                    raise RuntimeError("API server processs terminated") from e
-                time.sleep(1.0)
-                if (time.time() - start) > 30:
-                    raise RuntimeError("Timeout waiting for server start")
-
-    def __del__(self):
-        try:
-            process = self.process
-        except AttributeError:
-            pass
-        else:
-            process.terminate()
-            process.wait()
-
-
-@pytest.fixture(scope="session")
-def server():
-    try:
-        import fastapi
-        import uvicorn
-    except ModuleNotFoundError as e:
-        pytest.skip(f"Skipping server test because deps are missing: {e}")
-    runner = ServerRunner([])
-    yield runner
-
-
-def test_health(server: ServerRunner):
-    # Health check is part of getting the fixture.
-    ...
-
-
-def test_generate_non_streaming(server: ServerRunner):
-    resp = requests.post(
-        f"{server.url}/generate",
-        json={
-            "prompt": "Hi Bob",
-        },
-    )
-    resp.raise_for_status()
-    d = resp.json()
-    assert d["text"] == "Hi Bob", repr(d)
-
-
-def test_generate_streaming(server: ServerRunner):
-    resp = requests.post(
-        f"{server.url}/generate", json={"prompt": "Hi Bob!", "stream": True}
-    )
-    resp.raise_for_status()
-    full_contents = resp.content
-    expected_contents = b'{"text": "Hi Bob!"}\x00' * 5
-    assert (
-        full_contents == expected_contents
-    ), f"Expected {expected_contents!r} vs {full_contents!r}"
diff --git a/sharktank/tests/serving_poc/llm/service_v1_test.py b/sharktank/tests/serving_poc/llm/service_v1_test.py
deleted file mode 100644
index c010e2034..000000000
--- a/sharktank/tests/serving_poc/llm/service_v1_test.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import pytest
-
-from iree.runtime import (  # type: ignore
-    HalElementType,
-)
-
-from sharktank.serving_poc.framework.session import DeviceSession
-from sharktank.serving_poc.llm.config import (
-    CacheParams,
-    ModelParams,
-    ServiceParams,
-)
-
-from sharktank.serving_poc.llm.service import (
-    GenerateRequest,
-    GenerateResponsePart,
-)
-
-from sharktank.serving_poc.llm.attn_block_cache import (
-    create_attn_block_cache_module,
-    AttnBlockCache,
-)
-
-from sharktank.serving_poc.llm.impl.service_v1 import (
-    GenerateServiceV1,
-)
-
-from sharktank.serving_poc.llm.testing.fake_v1_module import (
-    create_fake_module,
-)
-
-
-@pytest.fixture
-def cache_params(model_params: ModelParams) -> CacheParams:
-    return CacheParams(model=model_params, device_block_count=128, block_pos_stride=16)
-
-
-@pytest.fixture
-def model_params() -> ModelParams:
-    return ModelParams(
-        module_name="AwesomeLLM",
-        module_abi_version=1,
-        attn_dtype=HalElementType.FLOAT_16,
-        max_seq_len=128,
-        transformer_block_count=32,
-        attn_head_count=32,
-        attn_head_dim=128,
-        block_seq_stride=16,
-        prefill_batch_sizes=[1, 4, 16],
-        decode_batch_sizes=[1, 4, 16],
-    )
-
-
-@pytest.fixture
-def uninitialized_session(model_params: ModelParams):
-    from iree.runtime._binding import disable_leak_checker  # type: ignore
-
-    disable_leak_checker()
-    session = DeviceSession(uri="local-task", queue_count=2)
-    yield session
-    session.shutdown()
-    del session
-
-
-@pytest.fixture
-def attn_block_cache(
-    uninitialized_session: DeviceSession, cache_params: CacheParams
-) -> AttnBlockCache:
-    return AttnBlockCache(uninitialized_session, cache_params)
-
-
-@pytest.fixture
-def session(
-    model_params: ModelParams,
-    uninitialized_session: DeviceSession,
-    attn_block_cache: AttnBlockCache,
-):
-    session = uninitialized_session
-    lms = session.create_module_set("AwesomeLLM", context_count=1)
-    lms.add(
-        create_attn_block_cache_module(attn_block_cache),
-        create_fake_module(session.device, "AwesomeLLM", model_params=model_params),
-    )
-    lms.initialize()
-    return session
-
-
-@pytest.fixture
-def service(
-    session: DeviceSession,
-    cache_params: CacheParams,
-    model_params: ModelParams,
-    attn_block_cache: AttnBlockCache,
-):
-    params = ServiceParams(cache=cache_params, model=model_params)
-    return GenerateServiceV1(session=session, params=params, cache=attn_block_cache)
-
-
-def test_single(service: GenerateServiceV1):
-    state = service.start()
-
-    async def task():
-        await state.set_sequences(
-            requests=[
-                GenerateRequest(
-                    "1",
-                    "hello, tell me a story",
-                    [3, 4, 5, 12, 23, 88, 10, 2, 5, 9, 12, 13, 99, 56, 33, 124, 73],
-                ),
-                GenerateRequest("2", "goodbye", [9, 10]),
-            ]
-        )
-        guarded_outputs = await state.prefill()
-        prefill_ids = await guarded_outputs.resolve(state.host_context)
-        print(
-            "PREFILL IDS:",
-            prefill_ids,
-            ":\n",
-            prefill_ids.map().asarray(
-                prefill_ids.shape, HalElementType.map_to_dtype(prefill_ids.element_type)
-            ),
-        )
-        await state.recycle()
-
-    state.host_context.run_sync(task())