"""Reconstructed post-image of the LLM export changes in this patch.

Besides the definitions below, the patch also:
  * registers ``"llm_dynamic_v1": LLMExporter`` in
    ``format_exporter_selector_mapping`` (src/kili/services/export/__init__.py);
  * adds ``"llm_dynamic_v1"`` to ``LabelFormat`` (src/kili/services/export/types.py).

``_format_raw_data`` and ``_format_json_response`` are only partially visible
in the patch hunks and are therefore not reproduced here.
"""

from ast import literal_eval
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, TypedDict, Union

# Explicit export surface (private helpers included so tests can import them).
__all__ = [
    "ExportLLMItem",
    "ExportLLMAsset",
    "RankingValue",
    "process_and_save",
    "process",
    "_process_llm_dynamic_v1",
    "_process_llm_v1",
    "_get_step_ranking_value",
    "_get_next_step_context",
    "_count_step",
    "_format_json_response_dynamic",
]


class ExportLLMItem(TypedDict):
    """One chat turn of an exported LLM asset."""

    role: str
    content: str
    id: Optional[str]
    chat_id: Optional[str]
    model: Optional[str]


class ExportLLMAsset(TypedDict):
    """LLM export asset format."""

    raw_data: List[ExportLLMItem]


class RankingValue(str, Enum):
    """Possible ranking of completion A against completion B for one step."""

    A_3 = "A_3"
    A_2 = "A_2"
    A_1 = "A_1"
    TIE = "TIE"
    B_1 = "B_1"
    B_2 = "B_2"
    B_3 = "B_3"


# --- Methods of LLMExporter (shown unindented here; they live at class scope) ---


def process_and_save(
    self, assets: List[Dict], output_filename: Path
) -> Optional[List[Dict[str, Union[List[str], str]]]]:
    """LLM specific process and save."""
    result = self.process(assets)
    self._save_assets_export(result, output_filename)


def process(self, assets: List[Dict]) -> List[Dict[str, Union[List[str], str]]]:
    """LLM specific process: dispatch on the requested label format."""
    if self.label_format == "llm_v1":
        return self._process_llm_v1(assets)
    # Only "llm_v1" and "llm_dynamic_v1" are routed to this exporter.
    return self._process_llm_dynamic_v1(assets)


def _process_llm_dynamic_v1(
    self, assets: List[Dict]
) -> List[Dict[str, Union[List[str], str]]]:
    """Build one export entry per asset, keyed by step index ("0", "1", ...).

    Each step carries the accumulated conversation context (previous prompts
    plus the completion that won each previous step's ranking) followed by the
    step's own prompt and completions. Assets with zero labelled steps are
    skipped entirely.
    """
    result = []
    for asset in assets:
        step_number = _count_step(asset)
        label = asset["latestLabel"]
        steps: Dict[str, Dict] = {}
        context: List[ExportLLMItem] = []
        formatted_asset = _format_raw_data(asset)
        for step in range(step_number):
            steps[str(step)] = {
                # context so far + this step's prompt/completions only
                "raw_data": context + _format_raw_data(asset, step),
                "status": asset["status"],
                "external_id": asset["externalId"],
                "metadata": asset["jsonMetadata"],
                "labels": [
                    {
                        "author": label["author"]["email"],
                        "created_at": label["createdAt"],
                        "label_type": label["labelType"],
                        "label": _format_json_response_dynamic(
                            self.project["jsonInterface"]["jobs"],
                            label["jsonResponse"],
                            step,
                        ),
                    }
                ],
            }
            # Winner of this step's ranking becomes part of the next step's context.
            context = context + _get_next_step_context(
                formatted_asset, label["jsonResponse"], step
            )

        if step_number > 0:
            result.append(steps)
    return result


def _process_llm_v1(self, assets: List[Dict]) -> List[Dict[str, Union[List[str], str]]]:
    """Flat (non-stepwise) export: one entry per asset with all of its labels."""
    return [
        {
            "raw_data": _format_raw_data(asset),
            "status": asset["status"],
            "external_id": asset["externalId"],
            "metadata": asset["jsonMetadata"],
            "labels": [
                {
                    "author": label["author"]["email"],
                    "created_at": label["createdAt"],
                    "label_type": label["labelType"],
                    "label": _format_json_response(
                        self.project["jsonInterface"]["jobs"], label["jsonResponse"]
                    ),
                }
                for label in asset["labels"]
            ],
        }
        for asset in assets
    ]


# --- Module-level helpers ---


def _get_step_ranking_value(json_response: Dict, step_number: int) -> str:
    """Return the ranking chosen at ``step_number`` (0-based), e.g. ``"B_3"``.

    Falls back to ``RankingValue.TIE`` (a ``str`` enum, ``== "TIE"``) when the
    step or its ranking sub-job is absent from the label.  The annotation is
    ``str`` because the main path returns a plain string slice; ``RankingValue``
    subclasses ``str`` so callers comparing against string lists are unaffected.
    """
    step_name = f"STEP_{step_number + 1}"
    prefix = f"{step_name}_"
    for category in json_response["CLASSIFICATION_JOB"]["categories"]:
        if category["name"] != step_name:
            continue
        for child_name, child_value in category["children"].items():
            if child_name == f"{step_name}_RANKING":
                raw_value = child_value["categories"][0]["name"]
                # "STEP_2_A_3" -> "A_3"
                return raw_value[len(prefix):]
    return RankingValue.TIE


def _get_next_step_context(
    formatted_asset: List[ExportLLMItem], json_response: Dict, step_number: int
) -> List[ExportLLMItem]:
    """Return the chat items of step ``step_number`` that feed the next step.

    That is the step's user prompt(s) plus the winning completion: the first
    completion wins on "A_*"/"TIE" rankings, the second on "B_*" rankings.

    NOTE(review): steps are advanced by counting items whose role is literally
    "assistant", but completions imported with a title get that title (e.g.
    "Model A") as their role — so for titled completions only step 0 is ever
    matched.  Confirm intended behavior for 3+ step conversations.
    """
    context: List[ExportLLMItem] = []
    skipped_context = 0  # number of "assistant" turns already passed
    completion_index = 0  # 0 -> first completion (A), 1 -> second (B)
    ranking = _get_step_ranking_value(json_response, step_number)
    for item in formatted_asset:
        if skipped_context > step_number:
            break

        if skipped_context == step_number:
            if item["role"] == "user":
                context.append(item)
            else:
                if completion_index == 0 and ranking in ["A_1", "A_2", "A_3", "TIE"]:
                    context.append(item)
                    break
                if completion_index == 1 and ranking in ["B_1", "B_2", "B_3"]:
                    context.append(item)
                    break
                completion_index += 1

        if item["role"] == "assistant":
            skipped_context += 1

    return context


def _count_step(asset: Dict) -> int:
    """Return the number of labelled conversation steps in the asset's latest label.

    Returns 0 when the label has no ``jsonResponse`` or no ``CLASSIFICATION_JOB``.
    """
    label = asset["latestLabel"]
    # BUGFIX: the original used `and`, which evaluated the second operand (and
    # raised KeyError) when "jsonResponse" was missing, and fell through to the
    # final KeyError-raising lookup when only "CLASSIFICATION_JOB" was missing.
    if "jsonResponse" not in label or "CLASSIFICATION_JOB" not in label["jsonResponse"]:
        return 0
    return len(label["jsonResponse"]["CLASSIFICATION_JOB"]["categories"])


def _format_json_response_dynamic(
    jobs_config: Dict, json_response: Dict, step_number: int
) -> Dict[str, Union[str, List[str]]]:
    """Format the label of one step, stripping the "STEP_<n>_" prefix from job names.

    Returns ``{}`` when the step has no answer in the label.
    """
    step_name = f"STEP_{step_number + 1}"
    for item in json_response["CLASSIFICATION_JOB"]["categories"]:
        if item["name"] != step_name:
            continue
        response_step = _format_json_response(jobs_config, item["children"])
        # NOTE(review): the str()/literal_eval round-trip also rewrites any
        # *value* containing "STEP_<n>_", not only keys; a per-key rename would
        # be safer — kept as-is to preserve current behavior.
        return literal_eval(str(response_step).replace(step_name + "_", ""))
    return {}
= [] chat_id = asset["jsonMetadata"].get("chat_id", None) @@ -131,24 +228,31 @@ def _format_raw_data(asset) -> List[Dict]: data = json.load(file) version = data.get("version", None) if version == "0.1": - for index, prompt in enumerate(data["prompts"]): + prompts = data["prompts"] + if step_number is not None: + prompts = [prompts[step_number]] + for index, prompt in enumerate(prompts): raw_data.append( - { - "role": prompt.get("title", "user"), - "content": prompt["prompt"], - "id": _safe_pop(chat_items_ids), - "chat_id": chat_id, - "model": None, - } + ExportLLMItem( + { + "role": prompt.get("title", "user"), + "content": prompt["prompt"], + "id": _safe_pop(chat_items_ids), + "chat_id": chat_id, + "model": None, + } + ) ) raw_data.extend( - { - "role": completion.get("title", "assistant"), - "content": completion["content"], - "id": _safe_pop(chat_items_ids), - "chat_id": chat_id, - "model": _safe_pop(models) if index == len(data["prompts"]) - 1 else None, - } + ExportLLMItem( + { + "role": completion.get("title", "assistant"), + "content": completion["content"], + "id": _safe_pop(chat_items_ids), + "chat_id": chat_id, + "model": _safe_pop(models) if (index == len(prompts) - 1) else None, + } + ) for completion in prompt["completions"] ) else: diff --git a/src/kili/services/export/format/llm/types.py b/src/kili/services/export/format/llm/types.py new file mode 100644 index 000000000..81eb3783b --- /dev/null +++ b/src/kili/services/export/format/llm/types.py @@ -0,0 +1,32 @@ +"""Custom Types.""" + +from enum import Enum +from typing import List, Optional, TypedDict + + +class ExportLLMItem(TypedDict): + """LLM asset chat part.""" + + role: str + content: str + id: Optional[str] + chat_id: Optional[str] + model: Optional[str] + + +class ExportLLMAsset(TypedDict): + """LLM export asset format.""" + + raw_data: List[ExportLLMItem] + + +class RankingValue(str, Enum): + """Possible value for ranking.""" + + A_3 = "A_3" + A_2 = "A_2" + A_1 = "A_1" + TIE = "TIE" + B_1 = 
"B_1" + B_2 = "B_2" + B_3 = "B_3" diff --git a/src/kili/services/export/types.py b/src/kili/services/export/types.py index 4151c6132..3add52afb 100644 --- a/src/kili/services/export/types.py +++ b/src/kili/services/export/types.py @@ -15,6 +15,7 @@ "pascal_voc", "geojson", "llm_v1", + "llm_dynamic_v1", ] diff --git a/tests/unit/services/export/test_llm.py b/tests/unit/services/export/test_llm.py new file mode 100644 index 000000000..32c63db48 --- /dev/null +++ b/tests/unit/services/export/test_llm.py @@ -0,0 +1,357 @@ +from unittest.mock import patch + +from kili.presentation.client.label import LabelClientMethods + +mock_json_interface = { + "jobs": { + "CLASSIFICATION_JOB": { + "content": { + "categories": { + "STEP_1": { + "children": ["STEP_1_RANKING", "STEP_1_QUALITY"], + "name": "Step 1", + "id": "category1", + }, + "STEP_2": { + "children": ["STEP_2_RANKING", "STEP_2_QUALITY"], + "name": "Step 2", + "id": "category2", + }, + }, + "input": "checkbox", + }, + "instruction": "Select the step", + "mlTask": "CLASSIFICATION", + "required": 0, + "isChild": False, + "isNew": False, + }, + "STEP_1_RANKING": { + "content": { + "categories": { + "STEP_1_A_3": { + "children": [], + "name": "A is much better than B", + "id": "category21", + }, + "STEP_1_A_2": { + "children": [], + "name": "A is better than B", + "id": "category22", + }, + "STEP_1_A_1": { + "children": [], + "name": "A is slightly better than B", + "id": "category23", + }, + "STEP_1_TIE": {"children": [], "name": "Tie", "id": "category24"}, + "STEP_1_B_1": { + "children": [], + "name": "B is slightly better than A", + "id": "category25", + }, + "STEP_1_B_2": { + "children": [], + "name": "B is better than A", + "id": "category26", + }, + "STEP_1_B_3": { + "children": [], + "name": "B is much better than A", + "id": "category27", + }, + }, + "input": "singleDropdown", + }, + "instruction": "Ranking", + "mlTask": "CLASSIFICATION", + "required": 0, + "isChild": True, + "isNew": False, + }, + 
"""Unit test for the llm_dynamic_v1 export format (reconstructed new file).

Fixtures with repetitive shape (the per-step ranking/quality sub-jobs and the
expected chat items) are built by small helpers; all values are identical to
the literal fixtures they replace.
"""

from unittest.mock import patch

from kili.presentation.client.label import LabelClientMethods

# (suffix, display name) pairs for one step's ranking dropdown, in order.
_RANKING_LABELS = [
    ("A_3", "A is much better than B"),
    ("A_2", "A is better than B"),
    ("A_1", "A is slightly better than B"),
    ("TIE", "Tie"),
    ("B_1", "B is slightly better than A"),
    ("B_2", "B is better than A"),
    ("B_3", "B is much better than A"),
]


def _ranking_job(step: int, first_category_id: int) -> dict:
    """Build the STEP_<step>_RANKING sub-job of the mocked interface."""
    categories = {
        f"STEP_{step}_{suffix}": {
            "children": [],
            "name": name,
            "id": f"category{first_category_id + idx}",
        }
        for idx, (suffix, name) in enumerate(_RANKING_LABELS)
    }
    return {
        "content": {"categories": categories, "input": "singleDropdown"},
        "instruction": "Ranking",
        "mlTask": "CLASSIFICATION",
        "required": 0,
        "isChild": True,
        "isNew": False,
    }


def _quality_job(step: int, first_category_id: int) -> dict:
    """Build the STEP_<step>_QUALITY sub-job of the mocked interface."""
    return {
        "content": {
            "categories": {
                f"STEP_{step}_GOOD": {
                    "children": [],
                    "name": "Both answers are very good",
                    "id": f"category{first_category_id}",
                },
                f"STEP_{step}_BAD": {
                    "children": [],
                    "name": "Both answers are very bad",
                    "id": f"category{first_category_id + 1}",
                },
            },
            "input": "singleDropdown",
        },
        "instruction": "Overall quality",
        "mlTask": "CLASSIFICATION",
        "required": 0,
        "isChild": True,
        "isNew": False,
    }


mock_json_interface = {
    "jobs": {
        "CLASSIFICATION_JOB": {
            "content": {
                "categories": {
                    "STEP_1": {
                        "children": ["STEP_1_RANKING", "STEP_1_QUALITY"],
                        "name": "Step 1",
                        "id": "category1",
                    },
                    "STEP_2": {
                        "children": ["STEP_2_RANKING", "STEP_2_QUALITY"],
                        "name": "Step 2",
                        "id": "category2",
                    },
                },
                "input": "checkbox",
            },
            "instruction": "Select the step",
            "mlTask": "CLASSIFICATION",
            "required": 0,
            "isChild": False,
            "isNew": False,
        },
        "STEP_1_RANKING": _ranking_job(1, 21),
        "STEP_1_QUALITY": _quality_job(1, 28),
        "STEP_2_RANKING": _ranking_job(2, 30),
        "STEP_2_QUALITY": _quality_job(2, 37),
    }
}

mock_fetch_assets = [
    {
        "pageResolutions": None,
        "resolution": None,
        "latestLabel": {
            "author": {
                "id": "cl7ugav7800hw0lqghawj9d30",
                "email": "jean.latapy@kili-technology.com",
                "firstname": "jean",
                "lastname": "latapy",
                "name": "jean latapy",
            },
            "jsonResponse": {
                "CLASSIFICATION_JOB": {
                    "categories": [
                        {
                            "name": "STEP_1",
                            "children": {
                                "STEP_1_RANKING": {"categories": [{"name": "STEP_1_B_3"}]}
                            },
                        },
                        {
                            "name": "STEP_2",
                            "children": {
                                "STEP_2_RANKING": {"categories": [{"name": "STEP_2_A_3"}]}
                            },
                        },
                    ]
                }
            },
            "createdAt": "2024-07-05T16:03:58.962Z",
            "isLatestLabelForUser": True,
            "labelType": "DEFAULT",
            "modelName": None,
        },
        "id": "cly8vy5t0000601u5ejzkw8rb",
        "externalId": "42d0c18601624413bb4bf6fd9a5436df",
        "content": "/tmp/content.json",
        "jsonContent": "",
        "jsonMetadata": {},
        "status": "LABELED",
    }
]

# Raw asset file content, as served to the exporter (parsed as JSON at runtime).
mock_raw_asset_content = """{
    "prompts": [
        {
            "prompt": "BLABLABLA",
            "completions": [
                {
                    "content": "Hello! How can I assist you today?",
                    "title": "Model A"
                },
                {
                    "content": "Hello! How can I assist you today? If you have any questions or need information, feel free to ask. I'm here to help!",
                    "title": "Model B"
                }
            ]
        },
        {
            "prompt": "BLIBLIBLI",
            "completions": [
                {
                    "content": "I apologize if I'm not understanding your message correctly. Could you please provide more context or let me know how I can assist you?",
                    "title": "Model A"
                },
                {
                    "content": "It seems like you're having a bit of fun! If you have any specific questions or if there's something particular you'd like to know about, just let me know. I'm here to provide information and assist you in any way I can!",
                    "title": "Model B"
                }
            ]
        }
    ],
    "type": "markdown",
    "version": "0.1"
}
"""

_HELLO_A = "Hello! How can I assist you today?"
_HELLO_B = (
    "Hello! How can I assist you today? If you have any questions or need information, "
    "feel free to ask. I'm here to help!"
)
_SORRY_A = (
    "I apologize if I'm not understanding your message correctly. Could you please provide "
    "more context or let me know how I can assist you?"
)
_FUN_B = (
    "It seems like you're having a bit of fun! If you have any specific questions or if "
    "there's something particular you'd like to know about, just let me know. I'm here to "
    "provide information and assist you in any way I can!"
)


def _chat_item(role: str, content: str) -> dict:
    """One expected raw_data chat turn (no chat/item ids in this fixture)."""
    return {"role": role, "content": content, "id": None, "chat_id": None, "model": None}


def _expected_step(raw_data: list, ranking: str) -> dict:
    """One expected export step for the single mocked asset."""
    return {
        "raw_data": raw_data,
        "status": "LABELED",
        "external_id": "42d0c18601624413bb4bf6fd9a5436df",
        "metadata": {},
        "labels": [
            {
                "author": "jean.latapy@kili-technology.com",
                "created_at": "2024-07-05T16:03:58.962Z",
                "label_type": "DEFAULT",
                "label": {"RANKING": [ranking]},
            }
        ],
    }


expected_export = [
    {
        # Step 0: first prompt plus both completions; step ranked B_3.
        "0": _expected_step(
            [
                _chat_item("user", "BLABLABLA"),
                _chat_item("Model A", _HELLO_A),
                _chat_item("Model B", _HELLO_B),
            ],
            "B_3",
        ),
        # Step 1: step-0 winner (Model B) kept as context, then the second
        # prompt with both completions; step ranked A_3.
        "1": _expected_step(
            [
                _chat_item("user", "BLABLABLA"),
                _chat_item("Model B", _HELLO_B),
                _chat_item("user", "BLIBLIBLI"),
                _chat_item("Model A", _SORRY_A),
                _chat_item("Model B", _FUN_B),
            ],
            "A_3",
        ),
    }
]


def test_export_dynamic(mocker):
    """End-to-end llm_dynamic_v1 export of a two-step RLHF conversation."""
    project = {
        "jsonInterface": mock_json_interface,
        "inputType": "LLM_RLHF",
        "title": "",
        "id": "project_id",
        "dataConnections": None,
    }
    kili = LabelClientMethods()
    kili.api_endpoint = "https://"  # type: ignore
    kili.api_key = ""  # type: ignore
    kili.graphql_client = mocker.MagicMock()  # pyright: ignore[reportGeneralTypeIssues]
    kili.http_client = mocker.MagicMock()  # pyright: ignore[reportGeneralTypeIssues]
    kili.kili_api_gateway = mocker.MagicMock()
    kili.kili_api_gateway.count_assets.return_value = 1
    kili.kili_api_gateway.get_project.return_value = project
    with open("/tmp/content.json", "w") as file:
        file.write(mock_raw_asset_content)
    with patch(
        "kili.services.export.format.base.fetch_assets", return_value=mock_fetch_assets
    ):
        result = kili.export_labels(
            project_id="project_id",
            fmt="llm_dynamic_v1",
            filename=None,
        )
    assert result == expected_export