From 21b2fd14099d6fba17edb1c9a31452053d94c07a Mon Sep 17 00:00:00 2001 From: derekdeming Date: Sat, 12 Oct 2024 00:28:24 -0400 Subject: [PATCH] adding type hints to functions --- swarm/core.py | 8 +++----- swarm/repl/repl.py | 9 +++++++-- swarm/types.py | 6 +++--- swarm/util.py | 7 ++++--- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/swarm/core.py b/swarm/core.py index 65dedf1a..ce180307 100644 --- a/swarm/core.py +++ b/swarm/core.py @@ -2,7 +2,7 @@ import copy import json from collections import defaultdict -from typing import List, Callable, Union +from typing import List, Callable, Union, Optional # Package/library imports from openai import OpenAI @@ -24,10 +24,8 @@ class Swarm: - def __init__(self, client=None): - if not client: - client = OpenAI() - self.client = client + def __init__(self, client: Optional[OpenAI]=None): + self.client = client or OpenAI() def get_chat_completion( self, diff --git a/swarm/repl/repl.py b/swarm/repl/repl.py index 50c3d0ee..90fa043e 100644 --- a/swarm/repl/repl.py +++ b/swarm/repl/repl.py @@ -1,6 +1,8 @@ import json +from typing import Any, Dict, List, Optional from swarm import Swarm +from ..types import Agent def process_and_print_streaming_response(response): @@ -34,7 +36,7 @@ def process_and_print_streaming_response(response): return chunk["response"] -def pretty_print_messages(messages) -> None: +def pretty_print_messages(messages: List[Dict[str, Any]]) -> None: for message in messages: if message["role"] != "assistant": continue @@ -58,7 +60,10 @@ def pretty_print_messages(messages) -> None: def run_demo_loop( - starting_agent, context_variables=None, stream=False, debug=False + starting_agent: Agent, + context_variables: Optional[Dict[str, Any]] = None, + stream=False, + debug=False ) -> None: client = Swarm() print("Starting Swarm CLI 🐝") diff --git a/swarm/types.py b/swarm/types.py index 0099abbd..fe6d94a1 100644 --- a/swarm/types.py +++ b/swarm/types.py @@ -3,7 +3,7 @@ ChatCompletionMessageToolCall, Function, ) -from typing import List, Callable, Union, Optional +from typing import List, Callable, Union, Optional, Any # Third-party imports from pydantic import BaseModel @@ -14,14 +14,14 @@ class Agent(BaseModel): name: str = "Agent" model: str = "gpt-4o" - instructions: Union[str, Callable[[], str]] = "You are a helpful agent." + instructions: Union[str, Callable[[dict[str, Any]], str]] = "You are a helpful agent." functions: List[AgentFunction] = [] tool_choice: str = None parallel_tool_calls: bool = True class Response(BaseModel): - messages: List = [] + messages: List[dict[str, Any]] = [] agent: Optional[Agent] = None context_variables: dict = {} diff --git a/swarm/util.py b/swarm/util.py index 520c8da2..8b51bd10 100644 --- a/swarm/util.py +++ b/swarm/util.py @@ -1,5 +1,6 @@ import inspect from datetime import datetime +from typing import Any, Callable, Dict, List def debug_print(debug: bool, *args: str) -> None: @@ -10,7 +11,7 @@ def debug_print(debug: bool, *args: str) -> None: print(f"\033[97m[\033[90m{timestamp}\033[97m]\033[90m {message}\033[0m") -def merge_fields(target, source): +def merge_fields(target: Dict[str, Any], source: Dict[str, Any]) -> None: for key, value in source.items(): if isinstance(value, str): target[key] += value @@ -18,7 +19,7 @@ def merge_fields(target, source): merge_fields(target[key], value) -def merge_chunk(final_response: dict, delta: dict) -> None: +def merge_chunk(final_response: Dict[str, Any], delta: Dict[str, Any]) -> None: delta.pop("role", None) merge_fields(final_response, delta) @@ -28,7 +29,7 @@ def merge_chunk(final_response: dict, delta: dict) -> None: merge_fields(final_response["tool_calls"][index], tool_calls[0]) -def function_to_json(func) -> dict: +def function_to_json(func: Callable) -> Dict[str, Any]: """ Converts a Python function into a JSON-serializable dictionary that describes the function's signature, including its name,