Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Functions support #14

Merged
merged 2 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions examples/example_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Пример - использование функций.
Для работы поисковой системы используется утилита из пакета gigachain.
Установите соответствующую библиотеку с помощью команды

pip install -U gigachain duckduckgo-search

"""

import json

from gigachat.models import Chat, Function, FunctionParameters, Messages, MessagesRole, FunctionCall
from gigachat import GigaChat

from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchRun


def search_ddg(search_query):
"""Поиск в DuckDuckGo. Полезен, когда нужно ответить на вопросы о текущих событиях. Входными данными должен быть поисковый запрос."""
return DuckDuckGoSearchRun().run(search_query)

# Используйте токен, полученный в личном кабинете из поля Авторизационные данные
with GigaChat(
credentials=...,
model=... # Model with functions
) as giga:
search = Function(
name="duckduckgo_search",
description="""Поиск в DuckDuckGo.
Полезен, когда нужно ответить на вопросы о текущих событиях.
Входными данными должен быть поисковый запрос.""",
parameters=FunctionParameters(
type="object",
properties={"query": {"type": "string", "description": "Поисковый запрос"}},
required=["query"],
),
)

messages = []
function_called = False
while True:
# Если предыдущий ответ LLM не был вызовом функции - просим пользователя продолжить диалог
if not function_called:
query = input("\033[92mUser: \033[0m")
messages.append(Messages(role=MessagesRole.USER, content=query))

chat = Chat(messages=messages, functions=[search])

resp = giga.chat(chat).choices[0]
mess = resp.message
messages.append(mess)

print("\033[93m" + f"Bot: \033[0m{mess.content}")

function_called = False
func_result = ""
if resp.finish_reason == "function_call":
print("\033[90m" + f" >> Processing function call {mess.function_call}" + "\033[0m")
if mess.function_call.name == "duckduckgo_search":
query = mess.function_call.arguments.get("query", None)
if query:
func_result = search_ddg(query)
print("\033[90m" + f" << Function result: {func_result}\n\n" + "\033[0m")

messages.append(
Messages(role=MessagesRole.FUNCTION,
content=json.dumps({"result": func_result}, ensure_ascii=False))
)
function_called = True
182 changes: 42 additions & 140 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gigachat"
version = "0.1.13"
version = "0.1.14"
description = "GigaChat. Python-library for GigaChain and LangChain"
authors = ["Konstantin Krestnikov <[email protected]>", "Sergey Malyshev <[email protected]>"]
license = "MIT"
Expand All @@ -10,7 +10,7 @@ packages = [{include = "gigachat", from = "src"}]

[tool.poetry.dependencies]
python = "^3.8"
pydantic = ">=1,<3"
pydantic = ">=1,<2"
httpx = "<1"

[tool.poetry.group.dev.dependencies]
Expand Down
6 changes: 6 additions & 0 deletions src/gigachat/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from gigachat.models.embedding import Embedding
from gigachat.models.embeddings import Embeddings
from gigachat.models.embeddings_usage import EmbeddingsUsage
from gigachat.models.function import Function
from gigachat.models.function_call import FunctionCall
from gigachat.models.function_parameters import FunctionParameters
from gigachat.models.messages import Messages
from gigachat.models.messages_chunk import MessagesChunk
from gigachat.models.messages_role import MessagesRole
Expand All @@ -26,6 +29,9 @@
"Embedding",
"Embeddings",
"EmbeddingsUsage",
"Function",
"FunctionCall",
"FunctionParameters",
"Messages",
"MessagesChunk",
"MessagesRole",
Expand Down
5 changes: 5 additions & 0 deletions src/gigachat/models/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional

from gigachat.models.function import Function
from gigachat.models.messages import Messages
from gigachat.pydantic_v1 import BaseModel

Expand Down Expand Up @@ -27,3 +28,7 @@ class Chat(BaseModel):
"""Интервал в секундах между отправкой токенов в потоке"""
profanity_check: Optional[bool] = None
"""Параметр цензуры"""
function_call: Optional[str] = None
"""Правила вызова функций"""
functions: Optional[List[Function]] = None
"""Набор функций, которые могут быть вызваны моделью"""
17 changes: 17 additions & 0 deletions src/gigachat/models/function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Any, Dict, Optional

from gigachat.models.function_parameters import FunctionParameters
from gigachat.pydantic_v1 import BaseModel


class Function(BaseModel):
"""Функция, которая может быть вызвана моделью"""

name: str
"""Название функции"""
description: Optional[str] = None
"""Описание функции"""
parameters: Optional[FunctionParameters] = None
"""Список параметров функции"""
return_parameters: Optional[Dict[Any, Any]] = None
"""Список возвращаемых параметров функции"""
12 changes: 12 additions & 0 deletions src/gigachat/models/function_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any, Dict, Optional

from gigachat.pydantic_v1 import BaseModel


class FunctionCall(BaseModel):
"""Вызов функции"""

name: str
"""Название функции"""
arguments: Optional[Dict[Any, Any]] = None
"""Описание функции"""
14 changes: 14 additions & 0 deletions src/gigachat/models/function_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any, Dict, List, Optional

from gigachat.pydantic_v1 import BaseModel, Field


class FunctionParameters(BaseModel):
"""Функция, которая может быть вызвана моделью"""

_type: str = Field(default="obect", alias="type")
"""Тип параметров функции"""
properties: Optional[Dict[Any, Any]] = None
"""Описание функции"""
required: Optional[List[str]] = None
"""Список обязательных параметров"""
7 changes: 6 additions & 1 deletion src/gigachat/models/messages.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Optional

from gigachat.models.function_call import FunctionCall
from gigachat.models.messages_role import MessagesRole
from gigachat.pydantic_v1 import BaseModel

Expand All @@ -7,8 +10,10 @@ class Messages(BaseModel):

role: MessagesRole
"""Роль автора сообщения"""
content: str
content: str = ""
"""Текст сообщения"""
function_call: Optional[FunctionCall] = None
"""Вызов функции"""

class Config:
use_enum_values = True
1 change: 1 addition & 0 deletions src/gigachat/models/messages_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ class MessagesRole(str, Enum):
ASSISTANT = "assistant"
SYSTEM = "system"
USER = "user"
FUNCTION = "function"
28 changes: 28 additions & 0 deletions tests/data/chat_completion_function.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"choices": [
{
"message": {
"content": "",
"role": "assistant",
"function_call": {
"name": "fc",
"arguments": {
"location": "Москва",
"num_days": 0
}
}
},
"index": 0,
"finish_reason": "function_call"
}
],
"created": 1706016586,
"model": "GigaChat-funcs",
"object": "chat.completion",
"usage": {
"prompt_tokens": 138,
"completion_tokens": 21,
"total_tokens": 159,
"system_tokens": 0
}
}
64 changes: 64 additions & 0 deletions tests/data/chat_function.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{
"messages": [
{
"content": "Какая погода в Москве сегодня?",
"role": "user"
}
],
"functions": [
{
"name": "fc",
"description": "Get an N-day weather forecast",
"parameters": {
"properties": {
"format": {
"description": "The temperature unit to use. Infer this from the users location.",
"enum": [
"celsius",
"fahrenheit"
],
"type": "string"
},
"location": {
"description": "Location, e.g. the city name",
"type": "string"
},
"num_days": {
"description": "The number of days to forecast",
"type": "integer"
}
},
"required": [
"location",
"num_days"
],
"type": "object"
},
"return_parameters": {
"properties": {
"error": {
"description": "returned if an error has occured, value is the error description string",
"type": "string"
},
"forecast": {
"description": "Weather condition descriptions",
"items": {
"type": "string"
},
"type": "array"
},
"location": {
"description": "Location, e.g. the city name",
"type": "string"
},
"temperature": {
"description": "The temperature forecast for the location",
"type": "integer"
}
},
"type": "object"
}
}
],
"model": "GigaChat-funcs"
}
17 changes: 17 additions & 0 deletions tests/unit_tests/gigachat/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
ACCESS_TOKEN = get_json("access_token.json")
TOKEN = get_json("token.json")
CHAT = Chat.parse_obj(get_json("chat.json"))
CHAT_FUNCTION = Chat.parse_obj(get_json("chat_function.json"))
CHAT_COMPLETION = get_json("chat_completion.json")
CHAT_COMPLETION_FUNCTION = get_json("chat_completion_function.json")
CHAT_COMPLETION_STREAM = get_bytes("chat_completion.stream")
EMBEDDINGS = get_json("embeddings.json")
MODELS = get_json("models.json")
Expand Down Expand Up @@ -249,6 +251,21 @@ def test_chat_update_token_error(httpx_mock: HTTPXMock) -> None:
assert client.token != access_token


def test_chat_with_functions(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION_FUNCTION)
access_token = "access_token"

with GigaChatSyncClient(base_url=BASE_URL, access_token=access_token) as client:
response = client.chat(CHAT_FUNCTION)

assert isinstance(response, ChatCompletion)
assert response.choices[0].finish_reason == "function_call"
assert response.choices[0].message.function_call is not None
assert response.choices[0].message.function_call.name == "fc"
assert response.choices[0].message.function_call.arguments is not None
assert response.choices[0].message.function_call.arguments == {"location": "Москва", "num_days": 0}


def test_embeddings(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=EMBEDDINGS_URL, json=EMBEDDINGS)

Expand Down