Skip to content

Commit

Permalink
feat(LAB-3244): on LLM dynamic projects export annotations at convers…
Browse files Browse the repository at this point in the history
…ation level
  • Loading branch information
FannyGaudin committed Nov 7, 2024
1 parent dc48337 commit a39328b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 19 deletions.
43 changes: 24 additions & 19 deletions src/kili/llm/services/export/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def export(
chat_items = label["chatItems"]
annotations = label["annotations"]
rounds = self._build_rounds(chat_items, annotations, json_interface)
total_rounds = len(rounds)
for step, round in enumerate(rounds):
raw_data = _format_raw_data(
round["context"]
Expand All @@ -60,25 +61,28 @@ def export(
label["id"],
obfuscated_models,
)
formatted_response = _format_json_response(
json_interface["jobs"],
round["annotations"],
round["completion"],
obfuscated_models,
)
label_data = {
"author": label["author"]["email"],
"created_at": label["createdAt"],
"label_type": label["labelType"],
"label": formatted_response["label"],
}
if step == total_rounds - 1 and formatted_response["conversation_label"]:
label_data["conversation_label"] = formatted_response["conversation_label"]

result[f"{step}"] = {
"external_id": asset["externalId"],
"metadata": asset["jsonMetadata"],
"models": _format_models_object(
asset["assetProjectModels"], obfuscated_models
),
"labels": [
{
"author": label["author"]["email"],
"created_at": label["createdAt"],
"label_type": label["labelType"],
"label": _format_json_response(
json_interface["jobs"],
round["annotations"],
round["completion"],
obfuscated_models,
),
}
],
"labels": [label_data],
"raw_data": raw_data,
"status": asset["status"],
}
Expand Down Expand Up @@ -123,10 +127,9 @@ def _init_round(self, context):

def _build_rounds(self, chat_items, annotations, json_interface):
"""A round is composed of a prompt with n pre-prompts and n completions."""
ordered_chat_items = sorted(chat_items, key=lambda x: x["createdAt"])
rounds = []
current_round = self._init_round([])
for chat_item in ordered_chat_items:
for chat_item in chat_items:
role = chat_item["role"].lower() if chat_item["role"] else None
if role == "user" or role == "system":
if current_round["prompt"] is not None:
Expand Down Expand Up @@ -156,7 +159,7 @@ def _build_rounds(self, chat_items, annotations, json_interface):
current_round["annotations"] += [
annotation
for annotation in annotations
if annotation["chatItemId"] == chat_item["id"]
if annotation["chatItemId"] == chat_item["id"] or annotation["chatItemId"] is None
]
rounds.append(current_round)
return rounds
Expand Down Expand Up @@ -192,8 +195,8 @@ def _format_comparison_annotation(annotation, completions, job, obfuscated_model

def _format_json_response(
jobs_config: Dict, annotations: List[Dict], completions: List[Dict], obfuscated_models: Dict
) -> Dict[str, Union[str, List[str]]]:
result = {}
) -> Dict[str, Dict]:
result = {"label": {}, "conversation_label": {}}
for annotation in annotations:
formatted_response = None
job = jobs_config[annotation["job"]]
Expand All @@ -210,8 +213,10 @@ def _format_json_response(
logging.warning(
f"Annotation with job {annotation['job']} with mlTask {job['mlTask']} not supported. Ignored in the export."
)
elif "level" in job and job["level"] == "conversation":
result["conversation_label"][annotation["job"]] = formatted_response
else:
result[annotation["job"]] = formatted_response
result["label"][annotation["job"]] = formatted_response

return result

Expand Down
50 changes: 50 additions & 0 deletions tests/unit/llm/services/export/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,30 @@
"isChild": False,
"isNew": True,
},
"CLASSIFICATION_JOB_0": {
"content": {
"categories": {
"GOOD": {"children": [], "name": "Good", "id": "category7"},
"BAD": {"children": [], "name": "Bad", "id": "category8"},
},
"input": "radio",
},
"level": "conversation",
"instruction": "Overall quality",
"mlTask": "CLASSIFICATION",
"required": 1,
"isChild": False,
"isNew": False,
},
"TRANSCRIPTION_JOB": {
"content": {"input": "textField"},
"level": "conversation",
"instruction": "Write something about the overall quality",
"mlTask": "TRANSCRIPTION",
"required": 1,
"isChild": False,
"isNew": False,
},
}
}

Expand Down Expand Up @@ -102,6 +126,28 @@
},
"chatItemId": "clzieuheg00587tc9d2k53ee1",
},
{
"id": "20241025134207822-9",
"job": "CLASSIFICATION_JOB_0",
"path": [],
"labelId": "clzief6q2003e7tc91jm46uii",
"chatItemId": None,
"annotationValue": {
"categories": ["GOOD"],
},
"__typename": "ClassificationAnnotation",
},
{
"id": "20241025134209366-10",
"job": "TRANSCRIPTION_JOB",
"path": [],
"labelId": "clzief6q2003e7tc91jm46uii",
"chatItemId": None,
"annotationValue": {
"text": "something",
},
"__typename": "TranscriptionAnnotation",
},
],
"author": {
"id": "user-1",
Expand Down Expand Up @@ -443,6 +489,10 @@
"created_at": "2024-08-06T12:30:42.122Z",
"label_type": "DEFAULT",
"label": {"COMPARISON_JOB": "A_2"},
"conversation_label": {
"CLASSIFICATION_JOB_0": ["GOOD"],
"TRANSCRIPTION_JOB": "something",
},
}
],
},
Expand Down

0 comments on commit a39328b

Please sign in to comment.