diff --git a/docs/sdk/llm.md b/docs/sdk/llm.md
new file mode 100644
index 000000000..d4548f464
--- /dev/null
+++ b/docs/sdk/llm.md
@@ -0,0 +1,3 @@
+# LLM module
+
+::: kili.llm.presentation.client.llm.LlmClientMethods
diff --git a/docs/sdk/tutorials/llm_project_setup.md b/docs/sdk/tutorials/llm_project_setup.md
new file mode 100644
index 000000000..e6dae9b85
--- /dev/null
+++ b/docs/sdk/tutorials/llm_project_setup.md
@@ -0,0 +1,165 @@
+
+
+
+# How to Set Up a Kili Project with a LLM Model and Create a Conversation
+
+In this tutorial, you'll learn how to set up a project in Kili Technology that integrates a Large Language Model (LLM), associate the LLM with your project, and create a conversation using the Kili Python SDK. By the end of this guide, you'll have a functional project ready to collect and label LLM outputs for comparison and evaluation.
+
+
+Here are the steps we will follow:
+
+1. Creating a Kili project with a custom interface
+2. Creating an LLM model
+3. Associating the model with the project
+4. Creating a conversation
+
+## Creating a Kili Project with a Custom Interface
+
+We will create a Kili project with a custom interface that includes a comparison job and a classification job. This interface will be used for labeling and comparing LLM outputs.
+
+Here's the JSON interface we will use:
+
+
+```python
+interface = {
+ "jobs": {
+ "COMPARISON_JOB": {
+ "content": {
+ "options": {
+ "IS_MUCH_BETTER": {"children": [], "name": "Is much better", "id": "option1"},
+ "IS_BETTER": {"children": [], "name": "Is better", "id": "option2"},
+ "IS_SLIGHTLY_BETTER": {
+ "children": [],
+ "name": "Is slightly better",
+ "id": "option3",
+ },
+ "TIE": {"children": [], "name": "Tie", "id": "option4", "mutual": True},
+ },
+ "input": "radio",
+ },
+ "instruction": "Pick the best answer",
+ "mlTask": "COMPARISON",
+ "required": 1,
+ "isChild": False,
+ "isNew": False,
+ },
+ "CLASSIFICATION_JOB": {
+ "content": {
+ "categories": {
+ "BOTH_ARE_GOOD": {"children": [], "name": "Both are good", "id": "category1"},
+ "BOTH_ARE_BAD": {"children": [], "name": "Both are bad", "id": "category2"},
+ },
+ "input": "radio",
+ },
+ "instruction": "Overall quality",
+ "mlTask": "CLASSIFICATION",
+ "required": 0,
+ "isChild": False,
+ "isNew": False,
+ },
+ }
+}
+```
+
+Now, we create the project using the `create_project` method, with type `LLM_INSTR_FOLLOWING`:
+
+
+```python
+from kili.client import Kili
+
+kili = Kili(
+ # api_endpoint="https://cloud.kili-technology.com/api/label/v2/graphql",
+)
+project = kili.create_project(
+ title="[Kili SDK Notebook]: LLM Project",
+ description="Project Description",
+ input_type="LLM_INSTR_FOLLOWING",
+ json_interface=interface,
+)
+project_id = project["id"]
+```
+
+## Creating an LLM Model
+
+We will now create an LLM model in Kili, by specifying the model's credentials and connector type. In this example, we will use the OpenAI SDK as the connector type.
+
+**Note**: Replace `api_key` and `endpoint` with your model's actual credentials.
+
+
+```python
+model_response = kili.llm.create_model(
+ organization_id="",
+ model={
+ "credentials": {
+ "api_key": "",
+ "endpoint": "https://api.openai.com/v1/",
+ },
+ "name": "My Model",
+ "type": "OPEN_AI_SDK",
+ },
+)
+
+model_id = model_response["id"]
+```
+
+You can now see the model integration by clicking **Manage organization** :
+
+![Model Integration]()
+
+## Associating the Model with the Project
+
+Next, we will associate the created model with our project by creating project models with different configurations. Each time you create a prompt, two models will be chosen from the project models in the project
+
+In this example, we compare **GPT 4o** and **GPT 4o Mini**, with different temperature settings :
+
+
+```python
+# First project model with a fixed temperature
+first_project_model = kili.llm.create_project_model(
+ project_id=project_id,
+ model_id=model_id,
+ configuration={
+ "model": "gpt-4o",
+ "temperature": 0.5,
+ },
+)
+
+# Second project model with a temperature range
+second_project_model = kili.llm.create_project_model(
+ project_id=project_id,
+ model_id=model_id,
+ configuration={
+ "model": "gpt-4o-mini",
+ "temperature": {"min": 0.2, "max": 0.8},
+ },
+)
+```
+
+You can now see the project models in the project settings :
+
+![Project Models]()
+
+## Creating a Conversation
+
+Now, we'll generate a conversation by providing a prompt.
+
+
+
+```python
+conversation = kili.llm.create_conversation(
+ project_id=project_id, prompt="Give me Schrödinger equation."
+)
+```
+
+It will add an asset to your project, and you'll be ready to start labeling the conversation :
+
+![Conversation]()
+
+## Summary
+
+In this tutorial, we've:
+
+- **Created a Kili project** with a custom interface for LLM output comparison.
+- **Registered an LLM model** in Kili with the necessary credentials.
+- **Associated the model** with the project by creating project models with different configurations.
+- **Generated a conversation** using a prompt, adding it to the project for labeling.
diff --git a/docs/tutorials.md b/docs/tutorials.md
index 0a4705d8f..5e944208f 100644
--- a/docs/tutorials.md
+++ b/docs/tutorials.md
@@ -71,6 +71,11 @@ For a more specific use case, follow [this tutorial](https://python-sdk-docs.kil
Webhooks are really similar to plugins, except they are self-hosted, and require a web service deployed at your end, callable by Kili. To learn how to use webhooks, follow [this tutorial](https://python-sdk-docs.kili-technology.com/latest/sdk/tutorials/webhooks_example/).
+## LLM
+
+[This tutorial](https://python-sdk-docs.kili-technology.com/latest/sdk/tutorials/llm_project_setup/) will show you how to set up a Kili project that uses a Large Language Model (LLM), create and associate the LLM model with the project, and initiate a conversation using the Kili Python SDK.
+
+
## Integrations
[This tutorial](https://python-sdk-docs.kili-technology.com/latest/sdk/tutorials/vertex_ai_automl_od/) will show you how train an object detection model with Vertex AI AutoML and Kili for faster annotation
diff --git a/mkdocs.yml b/mkdocs.yml
index a452c49fe..673174172 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -20,6 +20,7 @@ nav:
- Label: sdk/label.md
- Label Utils: sdk/label_utils.md
- Label Parsing: sdk/label_parsing.md
+ - LLM: sdk/llm.md
- Notification: sdk/notification.md
- Organization: sdk/organization.md
- Plugins: sdk/plugins.md
@@ -57,6 +58,7 @@ nav:
- Exporting Project Data:
- Exporting a Project: sdk/tutorials/export_a_kili_project.md
- Parsing Labels: sdk/tutorials/label_parsing.md
+ - LLM Projects: sdk/tutorials/llm_project_setup.md
- Setting Up Plugins:
- Developing Plugins: sdk/tutorials/plugins_development.md
- Plugin Example - Programmatic QA: sdk/tutorials/plugins_example.md
diff --git a/recipes/img/llm_conversation.png b/recipes/img/llm_conversation.png
new file mode 100644
index 000000000..4f4db1034
Binary files /dev/null and b/recipes/img/llm_conversation.png differ
diff --git a/recipes/img/llm_models.png b/recipes/img/llm_models.png
new file mode 100644
index 000000000..d63e51759
Binary files /dev/null and b/recipes/img/llm_models.png differ
diff --git a/recipes/img/llm_project_models.png b/recipes/img/llm_project_models.png
new file mode 100644
index 000000000..0de6b8cba
Binary files /dev/null and b/recipes/img/llm_project_models.png differ
diff --git a/recipes/llm_project_setup.ipynb b/recipes/llm_project_setup.ipynb
new file mode 100644
index 000000000..aea73f3d5
--- /dev/null
+++ b/recipes/llm_project_setup.ipynb
@@ -0,0 +1,280 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# How to Set Up a Kili Project with a LLM Model and Create a Conversation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, you'll learn how to set up a project in Kili Technology that integrates a Large Language Model (LLM), associate the LLM with your project, and create a conversation using the Kili Python SDK. By the end of this guide, you'll have a functional project ready to collect and label LLM outputs for comparison and evaluation.\n",
+ "\n",
+ "\n",
+ "Here are the steps we will follow:\n",
+ "\n",
+ "1. Creating a Kili project with a custom interface\n",
+ "2. Creating an LLM model\n",
+ "3. Associating the model with the project\n",
+ "4. Creating a conversation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating a Kili Project with a Custom Interface"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will create a Kili project with a custom interface that includes a comparison job and a classification job. This interface will be used for labeling and comparing LLM outputs.\n",
+ "\n",
+ "Here's the JSON interface we will use:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "interface = {\n",
+ " \"jobs\": {\n",
+ " \"COMPARISON_JOB\": {\n",
+ " \"content\": {\n",
+ " \"options\": {\n",
+ " \"IS_MUCH_BETTER\": {\"children\": [], \"name\": \"Is much better\", \"id\": \"option1\"},\n",
+ " \"IS_BETTER\": {\"children\": [], \"name\": \"Is better\", \"id\": \"option2\"},\n",
+ " \"IS_SLIGHTLY_BETTER\": {\n",
+ " \"children\": [],\n",
+ " \"name\": \"Is slightly better\",\n",
+ " \"id\": \"option3\",\n",
+ " },\n",
+ " \"TIE\": {\"children\": [], \"name\": \"Tie\", \"id\": \"option4\", \"mutual\": True},\n",
+ " },\n",
+ " \"input\": \"radio\",\n",
+ " },\n",
+ " \"instruction\": \"Pick the best answer\",\n",
+ " \"mlTask\": \"COMPARISON\",\n",
+ " \"required\": 1,\n",
+ " \"isChild\": False,\n",
+ " \"isNew\": False,\n",
+ " },\n",
+ " \"CLASSIFICATION_JOB\": {\n",
+ " \"content\": {\n",
+ " \"categories\": {\n",
+ " \"BOTH_ARE_GOOD\": {\"children\": [], \"name\": \"Both are good\", \"id\": \"category1\"},\n",
+ " \"BOTH_ARE_BAD\": {\"children\": [], \"name\": \"Both are bad\", \"id\": \"category2\"},\n",
+ " },\n",
+ " \"input\": \"radio\",\n",
+ " },\n",
+ " \"instruction\": \"Overall quality\",\n",
+ " \"mlTask\": \"CLASSIFICATION\",\n",
+ " \"required\": 0,\n",
+ " \"isChild\": False,\n",
+ " \"isNew\": False,\n",
+ " },\n",
+ " }\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we create the project using the `create_project` method, with type `LLM_INSTR_FOLLOWING`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from kili.client import Kili\n",
+ "\n",
+ "kili = Kili(\n",
+ " # api_endpoint=\"https://cloud.kili-technology.com/api/label/v2/graphql\",\n",
+ ")\n",
+ "project = kili.create_project(\n",
+ " title=\"[Kili SDK Notebook]: LLM Project\",\n",
+ " description=\"Project Description\",\n",
+ " input_type=\"LLM_INSTR_FOLLOWING\",\n",
+ " json_interface=interface,\n",
+ ")\n",
+ "project_id = project[\"id\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating an LLM Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will now create an LLM model in Kili, by specifying the model's credentials and connector type. In this example, we will use the OpenAI SDK as the connector type.\n",
+ "\n",
+ "**Note**: Replace `api_key` and `endpoint` with your model's actual credentials."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_response = kili.llm.create_model(\n",
+ " organization_id=\"\",\n",
+ " model={\n",
+ " \"credentials\": {\n",
+ " \"api_key\": \"\",\n",
+ " \"endpoint\": \"https://api.openai.com/v1/\",\n",
+ " },\n",
+ " \"name\": \"My Model\",\n",
+ " \"type\": \"OPEN_AI_SDK\",\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "model_id = model_response[\"id\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can now see the model integration by clicking **Manage organization** :\n",
+ "\n",
+ "![Model Integration](./img/llm_models.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Associating the Model with the Project"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we will associate the created model with our project by creating project models with different configurations. Each time you create a prompt, two models will be chosen from the project models in the project \n",
+ "\n",
+ "In this example, we compare **GPT 4o** and **GPT 4o Mini**, with different temperature settings :"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# First project model with a fixed temperature\n",
+ "first_project_model = kili.llm.create_project_model(\n",
+ " project_id=project_id,\n",
+ " model_id=model_id,\n",
+ " configuration={\n",
+ " \"model\": \"gpt-4o\",\n",
+ " \"temperature\": 0.5,\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "# Second project model with a temperature range\n",
+ "second_project_model = kili.llm.create_project_model(\n",
+ " project_id=project_id,\n",
+ " model_id=model_id,\n",
+ " configuration={\n",
+ " \"model\": \"gpt-4o-mini\",\n",
+ " \"temperature\": {\"min\": 0.2, \"max\": 0.8},\n",
+ " },\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can now see the project models in the project settings :\n",
+ "\n",
+ "![Project Models](./img/llm_project_models.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating a Conversation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we'll generate a conversation by providing a prompt.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "conversation = kili.llm.create_conversation(\n",
+ " project_id=project_id, prompt=\"Give me Schrödinger equation.\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "It will add an asset to your project, and you'll be ready to start labeling the conversation :\n",
+ "\n",
+ "![Conversation](./img/llm_conversation.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, we've:\n",
+ "\n",
+ "- **Created a Kili project** with a custom interface for LLM output comparison.\n",
+ "- **Registered an LLM model** in Kili with the necessary credentials.\n",
+ "- **Associated the model** with the project by creating project models with different configurations.\n",
+ "- **Generated a conversation** using a prompt, adding it to the project for labeling.\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/kili/adapters/kili_api_gateway/llm/mappers.py b/src/kili/adapters/kili_api_gateway/llm/mappers.py
index cf67c2691..6d0019365 100644
--- a/src/kili/adapters/kili_api_gateway/llm/mappers.py
+++ b/src/kili/adapters/kili_api_gateway/llm/mappers.py
@@ -109,3 +109,30 @@ def map_delete_project_model_input(project_model_id: str) -> Dict:
return {
"deleteProjectModelId": project_model_id,
}
+
+
+def map_create_llm_asset_input(data: Dict) -> Dict:
+ """Map the input for the createLLMAsset mutation."""
+ result = {
+ "authorId": data["author_id"],
+ }
+ if "status" in data:
+ result["status"] = data["status"]
+ if "label_type" in data:
+ result["labelType"] = data["label_type"]
+ return result
+
+
+def map_project_where(project_id: str) -> Dict:
+ """Map the 'where' parameter for mutations that require a ProjectWhere."""
+ return {"id": project_id}
+
+
+def map_create_chat_item_input(label_id: str, prompt: str) -> Dict:
+ """Map the input for the createChatItem mutation."""
+ return {"content": prompt, "role": "USER", "labelId": label_id}
+
+
+def map_asset_where(asset_id: str) -> Dict:
+ """Map the 'where' parameter for the createChatItem mutation."""
+ return {"id": asset_id}
diff --git a/src/kili/adapters/kili_api_gateway/llm/operations.py b/src/kili/adapters/kili_api_gateway/llm/operations.py
index b1e421164..c185fe55d 100644
--- a/src/kili/adapters/kili_api_gateway/llm/operations.py
+++ b/src/kili/adapters/kili_api_gateway/llm/operations.py
@@ -94,3 +94,25 @@ def get_project_models_query(fragment: str) -> str:
}}
}}
"""
+
+
+def get_create_llm_asset_mutation(fragment: str) -> str:
+ """Return the GraphQL createLLMAsset mutation."""
+ return f"""
+ mutation CreateLLMAsset($where: ProjectWhere!, $data: CreateLLMAssetData!) {{
+ createLLMAsset(where: $where, data: $data) {{
+ {fragment}
+ }}
+ }}
+ """
+
+
+def get_create_chat_item_mutation(fragment: str) -> str:
+ """Return the GraphQL createChatItem mutation."""
+ return f"""
+ mutation CreateChatItem($data: CreateChatItemData!, $where: AssetWhere!) {{
+ createChatItem(data: $data, where: $where) {{
+ {fragment}
+ }}
+ }}
+ """
diff --git a/src/kili/adapters/kili_api_gateway/llm/operations_mixin.py b/src/kili/adapters/kili_api_gateway/llm/operations_mixin.py
index 5a7483788..ac64012ba 100644
--- a/src/kili/adapters/kili_api_gateway/llm/operations_mixin.py
+++ b/src/kili/adapters/kili_api_gateway/llm/operations_mixin.py
@@ -1,6 +1,6 @@
"""Mixin extending Kili API Gateway class with Api Keys related operations."""
-from typing import Dict, Generator, Optional
+from typing import Dict, List, Optional, cast
from kili.adapters.kili_api_gateway.base import BaseOperationMixin
from kili.adapters.kili_api_gateway.helpers.queries import (
@@ -9,16 +9,22 @@
fragment_builder,
)
from kili.adapters.kili_api_gateway.llm.mappers import (
+ map_asset_where,
+ map_create_chat_item_input,
+ map_create_llm_asset_input,
map_create_model_input,
map_create_project_model_input,
map_delete_model_input,
map_delete_project_model_input,
+ map_project_where,
map_update_model_input,
map_update_project_model_input,
model_where_wrapper,
project_model_where_mapper,
)
from kili.adapters.kili_api_gateway.llm.operations import (
+ get_create_chat_item_mutation,
+ get_create_llm_asset_mutation,
get_create_model_mutation,
get_create_project_model_mutation,
get_delete_model_mutation,
@@ -30,15 +36,28 @@
get_update_project_model_mutation,
)
from kili.domain.llm import (
+ ChatItemDict,
+ ModelDict,
ModelToCreateInput,
ModelToUpdateInput,
OrganizationModelFilters,
+ ProjectModelDict,
ProjectModelFilters,
ProjectModelToCreateInput,
ProjectModelToUpdateInput,
)
from kili.domain.types import ListOrTuple
+DEFAULT_PROJECT_FIELDS = ["id", "name", "credentials", "type"]
+DEFAULT_PROJECT_MODEL_FIELDS = [
+ "id",
+ "configuration",
+ "model.id",
+ "model.credentials",
+ "model.name",
+ "model.type",
+]
+
class ModelConfigurationOperationMixin(BaseOperationMixin):
"""Mixin extending Kili API Gateway class with model configuration related operations."""
@@ -46,74 +65,79 @@ class ModelConfigurationOperationMixin(BaseOperationMixin):
def list_models(
self,
filters: OrganizationModelFilters,
- fields: ListOrTuple[str],
+ fields: Optional[ListOrTuple[str]] = None,
options: Optional[QueryOptions] = None,
- ) -> Generator[Dict, None, None]:
+ ) -> List[ModelDict]:
"""List models with given options."""
- fragment = fragment_builder(fields)
+ fragment = fragment_builder(fields or DEFAULT_PROJECT_FIELDS)
query = get_models_query(fragment)
where = model_where_wrapper(filters)
- return PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call(
- query,
- where,
- options if options else QueryOptions(disable_tqdm=False),
- "Retrieving organization models",
- None,
- )
-
- def get_model(self, model_id: str, fields: ListOrTuple[str]) -> Dict:
+ return [
+ cast(ModelDict, item)
+ for item in PaginatedGraphQLQuery(
+ self.graphql_client
+ ).execute_query_from_paginated_call(
+ query,
+ where,
+ options if options else QueryOptions(disable_tqdm=False),
+ "Retrieving organization models",
+ None,
+ )
+ ]
+
+ def get_model(self, model_id: str, fields: Optional[ListOrTuple[str]] = None) -> ModelDict:
"""Get a model by ID."""
- fragment = fragment_builder(fields)
+ fragment = fragment_builder(fields or DEFAULT_PROJECT_FIELDS)
query = get_model_query(fragment)
variables = {"modelId": model_id}
result = self.graphql_client.execute(query, variables)
return result["model"]
- def create_model(self, model: ModelToCreateInput) -> Dict:
+ def create_model(self, model: ModelToCreateInput) -> ModelDict:
"""Send a GraphQL request calling createModel resolver."""
payload = {"input": map_create_model_input(model)}
- fragment = fragment_builder(["id"])
+ fragment = fragment_builder(DEFAULT_PROJECT_FIELDS)
mutation = get_create_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["createModel"]
- def update_properties_in_model(self, model_id: str, model: ModelToUpdateInput) -> Dict:
+ def update_properties_in_model(self, model_id: str, model: ModelToUpdateInput) -> ModelDict:
"""Send a GraphQL request calling updateModel resolver."""
payload = {"id": model_id, "input": map_update_model_input(model)}
- fragment = fragment_builder(["id"])
+ fragment = fragment_builder(DEFAULT_PROJECT_FIELDS)
mutation = get_update_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["updateModel"]
- def delete_model(self, model_id: str) -> Dict:
+ def delete_model(self, model_id: str) -> bool:
"""Send a GraphQL request to delete an organization model."""
payload = map_delete_model_input(model_id)
mutation = get_delete_model_mutation()
result = self.graphql_client.execute(mutation, payload)
return result["deleteModel"]
- def create_project_model(self, project_model: ProjectModelToCreateInput) -> Dict:
+ def create_project_model(self, project_model: ProjectModelToCreateInput) -> ProjectModelDict:
"""Send a GraphQL request calling createModel resolver."""
payload = {"input": map_create_project_model_input(project_model)}
- fragment = fragment_builder(["id"])
+ fragment = fragment_builder(DEFAULT_PROJECT_MODEL_FIELDS)
mutation = get_create_project_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["createProjectModel"]
def update_project_model(
self, project_model_id: str, project_model: ProjectModelToUpdateInput
- ) -> Dict:
+ ) -> ProjectModelDict:
"""Send a GraphQL request calling updateProjectModel resolver."""
payload = {
"updateProjectModelId": project_model_id,
"input": map_update_project_model_input(project_model),
}
- fragment = fragment_builder(["id", "configuration"])
+ fragment = fragment_builder(DEFAULT_PROJECT_MODEL_FIELDS)
mutation = get_update_project_model_mutation(fragment)
result = self.graphql_client.execute(mutation, payload)
return result["updateProjectModel"]
- def delete_project_model(self, project_model_id: str) -> Dict:
+ def delete_project_model(self, project_model_id: str) -> bool:
"""Send a GraphQL request to delete a project model."""
payload = map_delete_project_model_input(project_model_id)
mutation = get_delete_project_model_mutation()
@@ -123,17 +147,49 @@ def delete_project_model(self, project_model_id: str) -> Dict:
def list_project_models(
self,
filters: ProjectModelFilters,
- fields: ListOrTuple[str],
+ fields: Optional[ListOrTuple[str]] = None,
options: Optional[QueryOptions] = None,
- ) -> Generator[Dict, None, None]:
+ ) -> List[ProjectModelDict]:
"""List project models with given options."""
- fragment = fragment_builder(fields)
+ fragment = fragment_builder(fields or DEFAULT_PROJECT_MODEL_FIELDS)
query = get_project_models_query(fragment)
where = project_model_where_mapper(filters)
- return PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call(
- query,
- where,
- options if options else QueryOptions(disable_tqdm=False),
- "Retrieving project models",
- None,
- )
+ return [
+ cast(ProjectModelDict, item)
+ for item in PaginatedGraphQLQuery(
+ self.graphql_client
+ ).execute_query_from_paginated_call(
+ query,
+ where,
+ options if options else QueryOptions(disable_tqdm=False),
+ "Retrieving project models",
+ None,
+ )
+ ]
+
+ def create_llm_asset(
+ self,
+ project_id: str,
+ author_id: str,
+ status: Optional[str] = None,
+ label_type: Optional[str] = None,
+ ) -> Dict:
+ """Create an LLM asset in a project, with optional status and label_type."""
+ where = map_project_where(project_id)
+ data = {"author_id": author_id, "status": status, "label_type": label_type}
+ data_mapped = map_create_llm_asset_input(data)
+ variables = {"where": where, "data": data_mapped}
+ fragment = fragment_builder(["id", "latestLabel.id"])
+ mutation = get_create_llm_asset_mutation(fragment)
+ result = self.graphql_client.execute(mutation, variables)
+ return result["createLLMAsset"]
+
+ def create_chat_item(self, asset_id: str, label_id: str, prompt: str) -> List[ChatItemDict]:
+ """Create a chat item associated with an asset."""
+ data = map_create_chat_item_input(label_id, prompt)
+ where = map_asset_where(asset_id)
+ variables = {"data": data, "where": where}
+ fragment = fragment_builder(["content", "id", "labelId", "modelId", "parentId", "role"])
+ mutation = get_create_chat_item_mutation(fragment)
+ result = self.graphql_client.execute(mutation, variables)
+ return [cast(ChatItemDict, item) for item in result["createChatItem"]]
diff --git a/src/kili/domain/llm.py b/src/kili/domain/llm.py
index 19e71d2bf..71d14c566 100644
--- a/src/kili/domain/llm.py
+++ b/src/kili/domain/llm.py
@@ -1,8 +1,10 @@
-"""API Key domain."""
+"""LLM domain."""
from dataclasses import dataclass
from enum import Enum
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union
+
+from typing_extensions import TypedDict
@dataclass
@@ -91,3 +93,50 @@ class ProjectModelFilters:
project_id: Optional[str] = None
model_id: Optional[str] = None
+
+
+class ChatItemRole(str, Enum):
+ """Enumeration of the supported chat item role."""
+
+ ASSISTANT = "ASSISTANT"
+ USER = "USER"
+
+
+class OpenAISDKCredentialsDict(TypedDict):
+ """Dict that represents model.Credentials for OpenAI SDK."""
+
+ api_key: str
+ endpoint: str
+
+
+class AzureOpenAICredentialsDict(TypedDict):
+ """Dict that represents model.Credentials for Azure OpenAI."""
+
+ api_key: str
+ endpoint: str
+ deployment_id: str
+
+
+class ModelDict(TypedDict):
+ """Dict that represents a Model."""
+
+ id: str
+ credentials: Union[AzureOpenAICredentialsDict, OpenAISDKCredentialsDict]
+ name: str
+ type: str
+
+
+class ProjectModelDict(TypedDict):
+ """Dict that represents a ProjectModel."""
+
+ id: str
+ configuration: Dict[str, Any]
+ model: ModelDict
+
+
+class ChatItemDict(TypedDict):
+ """Dict that represents a ChatItem."""
+
+ content: str
+ id: str
+ role: ChatItemRole
diff --git a/src/kili/llm/presentation/client/llm.py b/src/kili/llm/presentation/client/llm.py
index b72584cf9..b53fdc096 100644
--- a/src/kili/llm/presentation/client/llm.py
+++ b/src/kili/llm/presentation/client/llm.py
@@ -14,11 +14,14 @@
from kili.domain.asset import AssetExternalId, AssetFilters, AssetId
from kili.domain.llm import (
AzureOpenAICredentials,
+ ChatItemDict,
+ ModelDict,
ModelToCreateInput,
ModelToUpdateInput,
ModelType,
OpenAISDKCredentials,
OrganizationModelFilters,
+ ProjectModelDict,
ProjectModelFilters,
ProjectModelToCreateInput,
ProjectModelToUpdateInput,
@@ -29,21 +32,6 @@
from kili.use_cases.asset.utils import AssetUseCasesUtils
from kili.utils.logcontext import for_all_methods, log_call
-DEFAULT_ORGANIZATION_MODEL_FIELDS = [
- "id",
- "credentials",
- "name",
- "type",
-]
-
-DEFAULT_PROJECT_MODEL_FIELDS = [
- "configuration",
- "id",
- "model.credentials",
- "model.name",
- "model.type",
-]
-
@for_all_methods(log_call, exclude=["__init__"])
class LlmClientMethods:
@@ -100,26 +88,41 @@ def export(
warnings.warn(str(excp), stacklevel=2)
return None
- def models(self, organization_id: str, fields: Optional[List[str]] = None):
- """List models of given organization."""
- converted_filters = OrganizationModelFilters(
- organization_id=organization_id,
- )
+ def create_model(self, organization_id: str, model: dict) -> ModelDict:
+ # pylint: disable=line-too-long
+ """Create a new model in an organization.
- return list(
- self.kili_api_gateway.list_models(
- filters=converted_filters,
- fields=fields if fields else DEFAULT_ORGANIZATION_MODEL_FIELDS,
- )
- )
+ Args:
+ organization_id: Identifier of the organization.
+ model: A dictionary representing the model to create, containing:
+ - `name`: Name of the model.
+ - `type`: Type of the model, one of:
+ - `"AZURE_OPEN_AI"`
+ - `"OPEN_AI_SDK"`
+ - `credentials`: Credentials required for the model, depending on the type:
+ - For `"AZURE_OPEN_AI"` type:
+ - `api_key`: The API key for Azure OpenAI.
+ - `deployment_id`: The deployment ID within Azure.
+ - `endpoint`: The endpoint URL of the Azure OpenAI service.
+ - For `"OPEN_AI_SDK"` type:
+ - `api_key`: The API key for OpenAI SDK.
+ - `endpoint`: The endpoint URL of the OpenAI SDK service.
- def model(self, model_id: str, fields: Optional[List[str]] = None):
- return self.kili_api_gateway.get_model(
- model_id=model_id,
- fields=fields if fields else DEFAULT_ORGANIZATION_MODEL_FIELDS,
- )
+ Returns:
+ A dictionary containing the created model's details.
- def create_model(self, organization_id: str, model: dict):
+ Examples:
+ >>> # Example of creating an OpenAI SDK model
+ >>> model_data = {
+ ... "name": "My OpenAI SDK Model",
+ ... "type": "OPEN_AI_SDK",
+ ... "credentials": {
+ ... "api_key": "your_open_ai_api_key",
+ ... "endpoint": "https://api.openai.com/v1/"
+ ... }
+ ... }
+ >>> kili.llm.create_model(organization_id="your_organization_id", model=model_data)
+ """
credentials_data = model["credentials"]
model_type = ModelType(model["type"])
@@ -138,7 +141,74 @@ def create_model(self, organization_id: str, model: dict):
)
return self.kili_api_gateway.create_model(model=model_input)
- def update_properties_in_model(self, model_id: str, model: dict):
+ def models(self, organization_id: str, fields: Optional[List[str]] = None) -> List[ModelDict]:
+ # pylint: disable=line-too-long
+ """List models in an organization.
+
+ Args:
+ organization_id: Identifier of the organization.
+ fields: All the fields to request among the possible fields for the models.
+ Defaults to ["id", "credentials", "name", "type"].
+
+ Returns:
+ A list of models.
+
+ Examples:
+ >>> kili.llm.models(organization_id="your_organization_id")
+ """
+ converted_filters = OrganizationModelFilters(
+ organization_id=organization_id,
+ )
+
+ return list(self.kili_api_gateway.list_models(filters=converted_filters, fields=fields))
+
+ def model(self, model_id: str, fields: Optional[List[str]] = None) -> ModelDict:
+ # pylint: disable=line-too-long
+ """Retrieve a specific model.
+
+ Args:
+ model_id: Identifier of the model.
+ fields: All the fields to request among the possible fields for the models.
+ Defaults to ["id", "credentials", "name", "type"].
+
+ Returns:
+ A dictionary representing the model.
+
+ Examples:
+ >>> kili.llm.model(model_id="your_model_id")
+ """
+ return self.kili_api_gateway.get_model(
+ model_id=model_id,
+ fields=fields,
+ )
+
+ def update_properties_in_model(self, model_id: str, model: dict) -> ModelDict:
+ # pylint: disable=line-too-long
+ """Update properties of an existing model.
+
+ Args:
+ model_id: Identifier of the model to update.
+ model: A dictionary containing the properties to update, which may include:
+ - `name`: New name of the model.
+ - `credentials`: Updated credentials for the model, depending on the type:
+ - For `"AZURE_OPEN_AI"` type:
+ - `api_key`: The API key for Azure OpenAI.
+ - `deployment_id`: The deployment ID within Azure.
+ - `endpoint`: The endpoint URL of the Azure OpenAI service.
+ - For `"OPEN_AI_SDK"` type:
+ - `api_key`: The API key for OpenAI SDK.
+ - `endpoint`: The endpoint URL of the OpenAI SDK service.
+
+ Returns:
+ A dictionary containing the updated model's details.
+
+ Examples:
+ >>> # Update the name of a model
+ >>> kili.llm.update_properties_in_model(
+ ... model_id="your_model_id",
+ ... model={"name": "Updated Model Name"}
+ ... )
+ """
credentials_data = model.get("credentials")
credentials = None
@@ -165,13 +235,68 @@ def update_properties_in_model(self, model_id: str, model: dict):
model_id=model_id, model=model_input
)
- def delete_model(self, model_id: str):
+ def delete_model(self, model_id: str) -> bool:
+ # pylint: disable=line-too-long
+ """Delete a model from an organization.
+
+ Args:
+ model_id: Identifier of the model to delete.
+
+ Returns:
+ A dictionary indicating the result of the deletion.
+
+ Examples:
+ >>> kili.llm.delete_model(model_id="your_model_id")
+ """
return self.kili_api_gateway.delete_model(model_id=model_id)
+ def create_project_model(
+ self, project_id: str, model_id: str, configuration: dict
+ ) -> ProjectModelDict:
+ # pylint: disable=line-too-long
+ """Associate a model with a project.
+
+ Args:
+ project_id: Identifier of the project.
+ model_id: Identifier of the model to associate.
+ configuration: Configuration parameters for the project model.
+
+ Returns:
+ A dictionary containing the created project model's details.
+
+ Examples:
+ >>> configuration = {
+ ... # Configuration details specific to your use case
+ ... }
+ >>> kili.llm.create_project_model(
+ ... project_id="your_project_id",
+ ... model_id="your_model_id",
+ ... configuration={"temperature": 0.7}
+ ... )
+ """
+ project_model_input = ProjectModelToCreateInput(
+ project_id=project_id, model_id=model_id, configuration=configuration
+ )
+ return self.kili_api_gateway.create_project_model(project_model=project_model_input)
+
def project_models(
self, project_id: str, filters: Optional[Dict] = None, fields: Optional[List[str]] = None
- ):
- """List project models of given project."""
+ ) -> List[ProjectModelDict]:
+ """List models associated with a project.
+
+ Args:
+ project_id: Identifier of the project.
+ filters: Optional filters to apply. Possible keys:
+ - `model_id`: Identifier of a specific model to filter by.
+ fields: All the fields to request among the possible fields for the project models.
+ Defaults to ["configuration", "id", "model.credentials", "model.name", "model.type"].
+
+ Returns:
+ A list of project models.
+
+ Examples:
+ >>> kili.llm.project_models(project_id="your_project_id")
+ """
converted_filters = ProjectModelFilters(
project_id=project_id,
model_id=filters["model_id"] if filters and "model_id" in filters else None,
@@ -180,21 +305,82 @@ def project_models(
return list(
self.kili_api_gateway.list_project_models(
filters=converted_filters,
- fields=fields if fields else DEFAULT_PROJECT_MODEL_FIELDS,
+ fields=fields,
)
)
- def create_project_model(self, project_id: str, model_id: str, configuration: dict):
- project_model_input = ProjectModelToCreateInput(
- project_id=project_id, model_id=model_id, configuration=configuration
- )
- return self.kili_api_gateway.create_project_model(project_model=project_model_input)
+ def update_project_model(self, project_model_id: str, configuration: dict) -> ProjectModelDict:
+ """Update the configuration of a project model.
+
+ Args:
+ project_model_id: Identifier of the project model to update.
+ configuration: New configuration parameters.
- def update_project_model(self, project_model_id: str, configuration: dict):
+ Returns:
+ A dictionary containing the updated project model's details.
+
+ Examples:
+ >>> configuration = {
+ ... # Updated configuration details
+ ... }
+ >>> kili.llm.update_project_model(
+ ... project_model_id="your_project_model_id",
+ ... configuration=configuration
+ ... )
+ """
project_model_input = ProjectModelToUpdateInput(configuration=configuration)
return self.kili_api_gateway.update_project_model(
project_model_id=project_model_id, project_model=project_model_input
)
- def delete_project_model(self, project_model_id: str):
+ def delete_project_model(self, project_model_id: str) -> bool:
+ """Delete a project model.
+
+ Args:
+ project_model_id: Identifier of the project model to delete.
+
+ Returns:
+ A dictionary indicating the result of the deletion.
+
+ Examples:
+ >>> kili.llm.delete_project_model(project_model_id="your_project_model_id")
+ """
return self.kili_api_gateway.delete_project_model(project_model_id)
+
+ def create_conversation(self, project_id: str, prompt: str) -> List[ChatItemDict]:
+ # pylint: disable=line-too-long
+ """Create a new conversation in an LLM project starting with a user's prompt.
+
+ This method initiates a new conversation in the specified project by:
+ - Creating an LLM asset and label associated with the current user.
+ - Adding the user's prompt as the first chat item.
+ - Automatically generating assistant responses using the project's models.
+
+ Args:
+ project_id: The identifier of the project where the conversation will be created.
+ prompt: The initial prompt or message from the user to start the conversation.
+
+ Returns:
+ A list of chat items in the conversation, including the user's prompt and the assistant's responses.
+
+ Examples:
+ >>> PROMPT = "Hello, how can I improve my coding skills?"
+ >>> chat_items = kili.llm.create_conversation(project_id="your_project_id", prompt=PROMPT)
+
+ Notes:
+ - The first chat item corresponds to the user's prompt.
+ - The subsequent chat items are assistant responses generated by the project's models.
+ - An LLM asset and a label are created in the project with status "TODO" and labelType "PREDICTION".
+ """
+ user_id = self.kili_api_gateway.get_current_user(["id"])["id"]
+ llm_asset = self.kili_api_gateway.create_llm_asset(
+ project_id=project_id,
+ author_id=user_id,
+ status="TODO",
+ label_type="PREDICTION",
+ )
+ asset_id = llm_asset["id"]
+ label_id = llm_asset["latestLabel"]["id"]
+ return self.kili_api_gateway.create_chat_item(
+ asset_id=asset_id, label_id=label_id, prompt=prompt
+ )
diff --git a/tests/e2e/test_e2e_models.py b/tests/e2e/test_e2e_models.py
index feb742415..ecc8b5367 100644
--- a/tests/e2e/test_e2e_models.py
+++ b/tests/e2e/test_e2e_models.py
@@ -1,64 +1,66 @@
import pytest
from kili.client import Kili
+from kili.domain.llm import ChatItemRole
from kili.exceptions import GraphQLError
-
-@pytest.mark.e2e()
-def test_given_no_resources_when_creating_project_and_model_it_creates_and_manages_resources_correctly(
- kili: Kili,
-):
- project_title = "[E2E Test]: Model"
- project_description = "End-to-End Test Model and Project Model workflow"
- interface = {
- "jobs": {
- "COMPARISON_JOB": {
- "content": {
- "options": {
- "OPTION_A": {"children": [], "name": "Option A", "id": "optionA"},
- "OPTION_B": {"children": [], "name": "Option B", "id": "optionB"},
- },
- "input": "radio",
+PROJECT_TITLE = "[E2E Test]: Model"
+PROJECT_DESCRIPTION = "End-to-End Test Model and Project Model workflow"
+MODEL_NAME = "E2E Test Model"
+UPDATED_MODEL_NAME = "E2E Test Model Updated"
+PROMPT = "Hello, world !"
+
+INTERFACE = {
+ "jobs": {
+ "COMPARISON_JOB": {
+ "content": {
+ "options": {
+ "OPTION_A": {"children": [], "name": "Option A", "id": "optionA"},
+ "OPTION_B": {"children": [], "name": "Option B", "id": "optionB"},
},
- "instruction": "Select the best option",
- "mlTask": "COMPARISON",
- "required": 1,
- "isChild": False,
- "isNew": False,
- }
+ "input": "radio",
+ },
+ "instruction": "Select the best option",
+ "mlTask": "COMPARISON",
+ "required": 1,
+ "isChild": False,
+ "isNew": False,
}
}
+}
+
+@pytest.mark.e2e()
+def test_create_and_manage_project_and_model_resources(kili: Kili):
+ """Test the creation and management of project and model resources."""
organization_id = kili.organizations()[0]["id"]
project = kili.create_project(
- title=project_title,
- description=project_description,
+ title=PROJECT_TITLE,
+ description=PROJECT_DESCRIPTION,
input_type="LLM_INSTR_FOLLOWING",
- json_interface=interface,
+ json_interface=INTERFACE,
)
project_id = project["id"]
model_data = {
"credentials": {"api_key": "***", "endpoint": "https://api.openai.com"},
- "name": "E2E Test Model",
+ "name": MODEL_NAME,
"type": "OPEN_AI_SDK",
}
+
model = kili.llm.create_model(organization_id=organization_id, model=model_data)
model_id = model["id"]
created_model = kili.llm.model(model_id)
- assert created_model["name"] == model_data["name"]
+ assert created_model["name"] == MODEL_NAME
assert created_model["type"] == model_data["type"]
- updated_model_name = "E2E Test Model Updated"
- kili.llm.update_properties_in_model(
+ updated_model = kili.llm.update_properties_in_model(
model_id=model_id,
- model={"credentials": model_data["credentials"], "name": updated_model_name},
+ model={"credentials": model_data["credentials"], "name": UPDATED_MODEL_NAME},
)
-
- updated_model = kili.llm.model(model_id)
- assert updated_model["name"] == updated_model_name
+ assert updated_model["name"] == UPDATED_MODEL_NAME
project_model_config_1 = {
"model": "Test Model",
@@ -81,11 +83,16 @@ def test_given_no_resources_when_creating_project_and_model_it_creates_and_manag
project_models = kili.llm.project_models(project_id=project_id)
assert len(project_models) == 2
- first_project_model = project_models[0]
- assert first_project_model["id"] == project_model_id_1
+
+ def get_project_model_by_id(models, model_id):
+ return next((pm for pm in models if pm["id"] == model_id), None)
+
+ first_project_model = get_project_model_by_id(project_models, project_model_id_1)
+ assert first_project_model is not None
assert first_project_model["configuration"]["temperature"] == 0.5
- second_project_model = project_models[1]
- assert second_project_model["id"] == project_model_id_2
+
+ second_project_model = get_project_model_by_id(project_models, project_model_id_2)
+ assert second_project_model is not None
assert second_project_model["configuration"]["temperature"]["min"] == 0.2
assert second_project_model["configuration"]["temperature"]["max"] == 0.8
@@ -94,16 +101,29 @@ def test_given_no_resources_when_creating_project_and_model_it_creates_and_manag
project_model_id=project_model_id_1, configuration=updated_project_model_config_1
)
- project_models = kili.llm.project_models(project_id=project_id)
-
- assert len(project_models) == 2
- updated_project_model = next(
- project_model
- for project_model in kili.llm.project_models(project_id=project_id)
- if project_model["id"] == project_model_id_1
- )
+ updated_project_models = kili.llm.project_models(project_id=project_id)
+ updated_project_model = get_project_model_by_id(updated_project_models, project_model_id_1)
+ assert updated_project_model is not None
assert updated_project_model["configuration"] == updated_project_model_config_1
+ chat_items = kili.llm.create_conversation(project_id=project_id, prompt=PROMPT)
+
+ assert len(chat_items) == 3
+ assert chat_items[0]["content"] == PROMPT
+ assert chat_items[0]["role"] == ChatItemRole.USER
+ assert chat_items[1]["role"] == ChatItemRole.ASSISTANT
+ assert chat_items[2]["role"] == ChatItemRole.ASSISTANT
+
+ assets = kili.assets(project_id)
+ assert len(assets) == 1
+ created_asset = assets[0]
+ assert created_asset["status"] == "TODO"
+
+ labels = kili.labels(project_id)
+ assert len(labels) == 1
+ created_label = labels[0]
+ assert created_label["labelType"] == "PREDICTION"
+
kili.llm.delete_project_model(project_model_id=project_model_id_1)
kili.llm.delete_project_model(project_model_id=project_model_id_2)
diff --git a/tests/e2e/test_notebooks.py b/tests/e2e/test_notebooks.py
index 71542f840..9ca4c76e3 100644
--- a/tests/e2e/test_notebooks.py
+++ b/tests/e2e/test_notebooks.py
@@ -45,6 +45,7 @@ def process_notebook(notebook_filename: str) -> None:
"recipes/importing_video_assets.ipynb",
"recipes/inference_labels.ipynb",
"recipes/label_parsing.ipynb",
+ "recipes/llm_project_setup.ipynb",
"recipes/medical_imaging.ipynb",
# "recipes/ner_pre_annotations_openai.ipynb",
"recipes/ocr_pre_annotations.ipynb",
diff --git a/tests/unit/llm/test_create_conversation.py b/tests/unit/llm/test_create_conversation.py
new file mode 100644
index 000000000..c3be30581
--- /dev/null
+++ b/tests/unit/llm/test_create_conversation.py
@@ -0,0 +1,31 @@
+from kili.llm.presentation.client.llm import LlmClientMethods
+
+
+def test_create_conversation(mocker):
+ mock_get_current_user = {"id": "user_id"}
+ mock_llm_asset = {"id": "asset_id", "latestLabel": {"id": "label_id"}}
+ mock_chat_item = {
+ "id": "chat_item_id",
+ "asset_id": "asset_id",
+ "label_id": "label_id",
+ "prompt": "prompt text",
+ }
+
+ kili_api_gateway = mocker.MagicMock()
+ kili_api_gateway.get_current_user.return_value = mock_get_current_user
+ kili_api_gateway.create_llm_asset.return_value = mock_llm_asset
+ kili_api_gateway.create_chat_item.return_value = mock_chat_item
+
+ kili_llm = LlmClientMethods(kili_api_gateway)
+
+ result = kili_llm.create_conversation(project_id="project_id", prompt="prompt text")
+
+ assert result == mock_chat_item
+
+ kili_api_gateway.get_current_user.assert_called_once_with(["id"])
+ kili_api_gateway.create_llm_asset.assert_called_once_with(
+ project_id="project_id", author_id="user_id", status="TODO", label_type="PREDICTION"
+ )
+ kili_api_gateway.create_chat_item.assert_called_once_with(
+ asset_id="asset_id", label_id="label_id", prompt="prompt text"
+ )
diff --git a/tests/unit/llm/test_model.py b/tests/unit/llm/test_model.py
index efec6104c..9656427b4 100644
--- a/tests/unit/llm/test_model.py
+++ b/tests/unit/llm/test_model.py
@@ -1,5 +1,16 @@
import pytest
+from kili.adapters.kili_api_gateway.llm.mappers import (
+ map_create_model_input,
+ map_update_model_input,
+)
+from kili.domain.llm import (
+ AzureOpenAICredentials,
+ ModelToCreateInput,
+ ModelToUpdateInput,
+ ModelType,
+ OpenAISDKCredentials,
+)
from kili.llm.presentation.client.llm import LlmClientMethods
mock_list_models = [
@@ -49,6 +60,143 @@
mock_delete_model = {"id": "model_id"}
+def test_map_create_model_input_with_openai_sdk_credentials():
+ credentials = OpenAISDKCredentials(api_key="api_key", endpoint="https://api.openai.com/v1/")
+ input_data = ModelToCreateInput(
+ name="Test Model",
+ type=ModelType.OPEN_AI_SDK,
+ organization_id="org_id",
+ credentials=credentials,
+ )
+ expected_output = {
+ "credentials": {
+ "apiKey": "api_key",
+ "endpoint": "https://api.openai.com/v1/",
+ },
+ "name": "Test Model",
+ "type": ModelType.OPEN_AI_SDK.value,
+ "organizationId": "org_id",
+ }
+
+ result = map_create_model_input(input_data)
+ assert result == expected_output
+
+
+def test_map_create_model_input_with_azure_openai_credentials():
+ credentials = AzureOpenAICredentials(
+ api_key="api_key",
+ deployment_id="deployment_id",
+ endpoint="https://azure-openai-endpoint.com",
+ )
+ input_data = ModelToCreateInput(
+ name="Test Azure Model",
+ type=ModelType.AZURE_OPEN_AI,
+ organization_id="org_id",
+ credentials=credentials,
+ )
+ expected_output = {
+ "credentials": {
+ "apiKey": "api_key",
+ "deploymentId": "deployment_id",
+ "endpoint": "https://azure-openai-endpoint.com",
+ },
+ "name": "Test Azure Model",
+ "type": ModelType.AZURE_OPEN_AI.value,
+ "organizationId": "org_id",
+ }
+
+ result = map_create_model_input(input_data)
+ assert result == expected_output
+
+
+def test_map_update_model_input_update_name_only():
+ input_data = ModelToUpdateInput(name="Updated Model Name")
+ expected_output = {"name": "Updated Model Name"}
+
+ result = map_update_model_input(input_data)
+ assert result == expected_output
+
+
+def test_map_update_model_input_update_openai_sdk_credentials():
+ credentials = OpenAISDKCredentials(
+ api_key="new_api_key", endpoint="https://new-openai-endpoint.com"
+ )
+ input_data = ModelToUpdateInput(credentials=credentials)
+ expected_output = {
+ "credentials": {
+ "apiKey": "new_api_key",
+ "endpoint": "https://new-openai-endpoint.com",
+ }
+ }
+
+ result = map_update_model_input(input_data)
+ assert result == expected_output
+
+
+def test_map_update_model_input_update_azure_openai_credentials():
+ credentials = AzureOpenAICredentials(
+ api_key="new_api_key",
+ deployment_id="new_deployment_id",
+ endpoint="https://new-azure-openai-endpoint.com",
+ )
+ input_data = ModelToUpdateInput(credentials=credentials)
+ expected_output = {
+ "credentials": {
+ "apiKey": "new_api_key",
+ "deploymentId": "new_deployment_id",
+ "endpoint": "https://new-azure-openai-endpoint.com",
+ }
+ }
+
+ result = map_update_model_input(input_data)
+ assert result == expected_output
+
+
+def test_map_update_model_input_update_name_and_openai_sdk_credentials():
+ credentials = OpenAISDKCredentials(
+ api_key="new_api_key", endpoint="https://new-openai-endpoint.com"
+ )
+ input_data = ModelToUpdateInput(name="Updated Model Name", credentials=credentials)
+ expected_output = {
+ "name": "Updated Model Name",
+ "credentials": {
+ "apiKey": "new_api_key",
+ "endpoint": "https://new-openai-endpoint.com",
+ },
+ }
+
+ result = map_update_model_input(input_data)
+ assert result == expected_output
+
+
+def test_map_update_model_input_update_name_and_azure_openai_credentials():
+ credentials = AzureOpenAICredentials(
+ api_key="new_api_key",
+ deployment_id="new_deployment_id",
+ endpoint="https://new-azure-openai-endpoint.com",
+ )
+ input_data = ModelToUpdateInput(name="Updated Model Name", credentials=credentials)
+ expected_output = {
+ "name": "Updated Model Name",
+ "credentials": {
+ "apiKey": "new_api_key",
+ "deploymentId": "new_deployment_id",
+ "endpoint": "https://new-azure-openai-endpoint.com",
+ },
+ }
+
+ result = map_update_model_input(input_data)
+ assert result == expected_output
+
+
+def test_map_update_model_input_no_updates():
+ input_data = ModelToUpdateInput()
+ expected_output = {}
+
+ result = map_update_model_input(input_data)
+ assert result == expected_output
+
+
def test_list_models(mocker):
kili_api_gateway = mocker.MagicMock()
kili_api_gateway.list_models.return_value = mock_list_models