-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from openradx/rag-task-batching
Rag task batching
- Loading branch information
Showing
29 changed files
with
700 additions
and
150 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import Any, Protocol | ||
|
||
from adit_radis_shared.common.mixins import ViewProtocol | ||
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator | ||
from django.db.models.query import QuerySet | ||
from django.http import HttpRequest | ||
|
||
|
||
# TODO: Move this to adit_radis_shared package. PR: https://github.com/openradx/adit-radis-shared/pull/5 | ||
class RelatedPaginationMixinProtocol(ViewProtocol, Protocol): | ||
request: HttpRequest | ||
object_list: QuerySet | ||
paginate_by: int | ||
|
||
def get_object(self) -> Any: ... | ||
|
||
def get_context_data(self, **kwargs) -> dict[str, Any]: ... | ||
|
||
def get_related_queryset(self) -> QuerySet: ... | ||
|
||
|
||
class RelatedPaginationMixin: | ||
"""This mixin provides pagination for a related queryset. This makes it possible to | ||
paginate a related queryset in a DetailView. The related queryset is obtained by | ||
the `get_related_queryset()` method that must be implemented by the subclass. | ||
If used in combination with `RelatedFilterMixin`, the `RelatedPaginationMixin` must be | ||
inherited first.""" | ||
|
||
def get_related_queryset(self: RelatedPaginationMixinProtocol) -> QuerySet: | ||
raise NotImplementedError("You must implement this method") | ||
|
||
def get_context_data(self: RelatedPaginationMixinProtocol, **kwargs): | ||
context = super().get_context_data(**kwargs) | ||
|
||
if "object_list" in context: | ||
queryset = context["object_list"] | ||
else: | ||
queryset = self.get_related_queryset() | ||
|
||
paginator = Paginator(queryset, self.paginate_by) | ||
page = self.request.GET.get("page") | ||
|
||
if page is None: | ||
page = 1 | ||
|
||
try: | ||
paginated_queryset = paginator.page(page) | ||
except PageNotAnInteger: | ||
paginated_queryset = paginator.page(1) | ||
except EmptyPage: | ||
paginated_queryset = paginator.page(paginator.num_pages) | ||
|
||
context["object_list"] = paginated_queryset | ||
context["paginator"] = paginator | ||
context["is_paginated"] = paginated_queryset.has_other_pages() | ||
context["page_obj"] = paginated_queryset | ||
|
||
return context |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import asyncio | ||
from typing import Callable, ContextManager | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
from faker import Faker | ||
|
||
|
||
@pytest.fixture | ||
def report_body() -> str: | ||
report_body = Faker().sentences(nb=40) | ||
return " ".join(report_body) | ||
|
||
|
||
@pytest.fixture | ||
def question_body() -> str: | ||
question_body = Faker().sentences(nb=1) | ||
return " ".join(question_body) | ||
|
||
|
||
@pytest.fixture | ||
def openai_chat_completions_mock() -> Callable[[str], ContextManager]: | ||
def _openai_chat_completions_mock(content: str) -> ContextManager: | ||
mock_openai = MagicMock() | ||
mock_response = MagicMock(choices=[MagicMock(message=MagicMock(content=content))]) | ||
future = asyncio.Future() | ||
future.set_result(mock_response) | ||
mock_openai.chat.completions.create.return_value = future | ||
|
||
return mock_openai | ||
|
||
return _openai_chat_completions_mock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
from radis.core.utils.chat_client import AsyncChatClient | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_ask_question(report_body, question_body, openai_chat_completions_mock): | ||
openai_mock = openai_chat_completions_mock("Fake Answer") | ||
|
||
with patch("openai.AsyncOpenAI", return_value=openai_mock): | ||
answer = await AsyncChatClient().ask_question(report_body, "en", question_body) | ||
|
||
assert answer == "Fake Answer" | ||
assert openai_mock.chat.completions.create.call_count == 1 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_ask_yes_no_question(report_body, question_body, openai_chat_completions_mock): | ||
openai_yes_mock = openai_chat_completions_mock("Yes") | ||
openai_no_mock = openai_chat_completions_mock("No") | ||
|
||
with patch("openai.AsyncOpenAI", return_value=openai_yes_mock): | ||
answer = await AsyncChatClient().ask_yes_no_question(report_body, "en", question_body) | ||
|
||
assert answer == "yes" | ||
assert openai_yes_mock.chat.completions.create.call_count == 1 | ||
|
||
with patch("openai.AsyncOpenAI", return_value=openai_no_mock): | ||
answer = await AsyncChatClient().ask_yes_no_question(report_body, "en", question_body) | ||
|
||
assert answer == "no" | ||
assert openai_no_mock.chat.completions.create.call_count == 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from typing import Generic, TypeVar, cast | ||
|
||
import factory | ||
from faker import Faker | ||
|
||
from radis.reports.factories import ModalityFactory | ||
|
||
from .models import Answer, Question, RagInstance, RagJob, RagTask | ||
|
||
T = TypeVar("T") | ||
|
||
fake = Faker() | ||
|
||
|
||
class BaseDjangoModelFactory(Generic[T], factory.django.DjangoModelFactory): | ||
@classmethod | ||
def create(cls, *args, **kwargs) -> T: | ||
return super().create(*args, **kwargs) | ||
|
||
|
||
SearchProviders = ("OpenSearch", "Vespa", "Elasticsearch") | ||
PatientSexes = ["", "M", "F"] | ||
|
||
|
||
class RagJobFactory(BaseDjangoModelFactory): | ||
class Meta: | ||
model = RagJob | ||
|
||
title = factory.Faker("sentence", nb_words=3) | ||
provider = factory.Faker("random_element", elements=SearchProviders) | ||
group = factory.SubFactory("adit_radis_shared.accounts.factories.GroupFactory") | ||
query = factory.Faker("word") | ||
language = factory.SubFactory("radis.reports.factories.LanguageFactory") | ||
study_date_from = factory.Faker("date") | ||
study_date_till = factory.Faker("date") | ||
study_description = factory.Faker("sentence", nb_words=5) | ||
patient_sex = factory.Faker("random_element", elements=PatientSexes) | ||
age_from = factory.Faker("random_int", min=0, max=100) | ||
age_till = factory.Faker("random_int", min=0, max=100) | ||
|
||
@factory.post_generation | ||
def modalities(self, create, extracted, **kwargs): | ||
if not create: | ||
return | ||
|
||
self = cast(RagJob, self) | ||
|
||
if extracted: | ||
for modality in extracted: | ||
self.modalities.add(modality) | ||
else: | ||
modality = ModalityFactory() | ||
self.modalities.add(modality) | ||
|
||
|
||
class QuestionFactory(BaseDjangoModelFactory[Question]): | ||
class Meta: | ||
model = Question | ||
|
||
job = factory.SubFactory("radis.rag.factories.RagJobFactory") | ||
question = factory.Faker("sentence", nb_words=10) | ||
accepted_answer = factory.Faker("random_element", elements=[a[0] for a in Answer.choices]) | ||
|
||
|
||
class RagTaskFactory(BaseDjangoModelFactory[RagTask]): | ||
class Meta: | ||
model = RagTask | ||
|
||
job = factory.SubFactory("radis.rag.factories.RagJobFactory") | ||
|
||
|
||
class RagInstanceFactory(BaseDjangoModelFactory[RagInstance]): | ||
class Meta: | ||
model = RagInstance | ||
|
||
task = factory.SubFactory("radis.rag.factories.RagTaskFactory") | ||
|
||
@factory.post_generation | ||
def reports(self, create, extracted, **kwargs): | ||
if not create: | ||
return | ||
|
||
self = cast(RagInstance, self) | ||
|
||
if extracted: | ||
for report in extracted: | ||
self.reports.add(report) | ||
else: | ||
from radis.reports.factories import ReportFactory | ||
|
||
self.reports.add(*[ReportFactory() for _ in range(3)]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.