Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding type hints to functions #44

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions swarm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import copy
import json
from collections import defaultdict
from typing import List, Callable, Union
from typing import List, Callable, Union, Optional

# Package/library imports
from openai import OpenAI
Expand All @@ -24,10 +24,8 @@


class Swarm:
def __init__(self, client=None):
if not client:
client = OpenAI()
self.client = client
def __init__(self, client: Optional[OpenAI]=None):
self.client = client or OpenAI()

def get_chat_completion(
self,
Expand Down
9 changes: 7 additions & 2 deletions swarm/repl/repl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from typing import Any, Dict, List, Optional

from swarm import Swarm
from ..types import Agent


def process_and_print_streaming_response(response):
Expand Down Expand Up @@ -34,7 +36,7 @@ def process_and_print_streaming_response(response):
return chunk["response"]


def pretty_print_messages(messages) -> None:
def pretty_print_messages(messages: List[Dict[str, Any]]) -> None:
for message in messages:
if message["role"] != "assistant":
continue
Expand All @@ -58,7 +60,10 @@ def pretty_print_messages(messages) -> None:


def run_demo_loop(
starting_agent, context_variables=None, stream=False, debug=False
starting_agent: Agent,
context_variables: Optional[Dict[str, Any]] = None,
stream=False,
debug=False
) -> None:
client = Swarm()
print("Starting Swarm CLI 🐝")
Expand Down
6 changes: 3 additions & 3 deletions swarm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
ChatCompletionMessageToolCall,
Function,
)
from typing import List, Callable, Union, Optional
from typing import List, Callable, Union, Optional, Any

# Third-party imports
from pydantic import BaseModel
Expand All @@ -14,14 +14,14 @@
class Agent(BaseModel):
name: str = "Agent"
model: str = "gpt-4o"
instructions: Union[str, Callable[[], str]] = "You are a helpful agent."
instructions: Union[str, Callable[[dict[str, Any]], str]] = "You are a helpful agent."
functions: List[AgentFunction] = []
tool_choice: str = None
parallel_tool_calls: bool = True


class Response(BaseModel):
messages: List = []
messages: List[dict[str, Any]] = []
agent: Optional[Agent] = None
context_variables: dict = {}

Expand Down
7 changes: 4 additions & 3 deletions swarm/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from datetime import datetime
from typing import Any, Callable, Dict, List


def debug_print(debug: bool, *args: str) -> None:
Expand All @@ -10,15 +11,15 @@ def debug_print(debug: bool, *args: str) -> None:
print(f"\033[97m[\033[90m{timestamp}\033[97m]\033[90m {message}\033[0m")


def merge_fields(target, source):
def merge_fields(target: Dict[str, Any], source: Dict[str, Any]) -> None:
for key, value in source.items():
if isinstance(value, str):
target[key] += value
elif value is not None and isinstance(value, dict):
merge_fields(target[key], value)


def merge_chunk(final_response: dict, delta: dict) -> None:
def merge_chunk(final_response: Dict[str, Any], delta: Dict[str, Any]) -> None:
delta.pop("role", None)
merge_fields(final_response, delta)

Expand All @@ -28,7 +29,7 @@ def merge_chunk(final_response: dict, delta: dict) -> None:
merge_fields(final_response["tool_calls"][index], tool_calls[0])


def function_to_json(func) -> dict:
def function_to_json(func: Callable) -> Dict[str, Any]:
"""
Converts a Python function into a JSON-serializable dictionary
that describes the function's signature, including its name,
Expand Down