diff --git a/elleelleaime/generate/strategies/models/mistral/__init__.py b/elleelleaime/generate/strategies/models/mistral/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/elleelleaime/generate/strategies/models/mistral/mistral.py b/elleelleaime/generate/strategies/models/mistral/mistral.py
new file mode 100644
index 00000000..a32fad11
--- /dev/null
+++ b/elleelleaime/generate/strategies/models/mistral/mistral.py
@@ -0,0 +1,76 @@
+from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy
+
+from dotenv import load_dotenv
+from typing import Any, List
+
+import os
+import mistralai
+import backoff
+
+
+def _is_permanent_error(exc: Exception) -> bool:
+    """Return True when retrying a failed Mistral request cannot succeed.
+
+    4xx responses other than 408 (timeout) and 429 (rate limit) are
+    deterministic client errors such as a bad API key or a malformed request.
+    """
+    # NOTE(review): relies on SDKError exposing ``status_code``; exceptions
+    # without one (e.g. AssertionError) are always treated as retryable.
+    status = getattr(exc, "status_code", None)
+    return (
+        isinstance(status, int)
+        and 400 <= status < 500
+        and status not in (408, 429)
+    )
+
+
+class MistralModels(PatchGenerationStrategy):
+    """Patch generation strategy backed by the Mistral chat completion API."""
+
+    def __init__(self, model_name: str, **kwargs) -> None:
+        """Store the sampling settings and build the API client.
+
+        Args:
+            model_name: Model identifier sent with every request.
+            **kwargs: Optional ``temperature`` (default 0.0) and
+                ``n_samples`` (default 1).
+        """
+        self.model_name = model_name
+        self.temperature = kwargs.get("temperature", 0.0)
+        self.n_samples = kwargs.get("n_samples", 1)
+
+        # The API key is read from MISTRAL_API_KEY, optionally via a .env file.
+        load_dotenv()
+        self.client = mistralai.Mistral(os.getenv("MISTRAL_API_KEY"))
+
+    @backoff.on_exception(
+        backoff.expo,
+        # HTTPValidationError is deliberately not retried: the same payload
+        # would be rejected again on every attempt.
+        (mistralai.models.SDKError, AssertionError),
+        giveup=_is_permanent_error,
+    )
+    def _completions_with_backoff(self, **kwargs):
+        """Call ``chat.complete`` with exponential backoff.
+
+        Raises:
+            mistralai.models.SDKError: On a non-retryable 4xx response.
+            mistralai.models.HTTPValidationError: Propagated without retrying.
+        """
+        response = self.client.chat.complete(**kwargs)
+        if response is None:
+            # Raised explicitly because ``assert`` is stripped under ``-O``.
+            raise AssertionError("Mistral API returned no response")
+        return response
+
+    def _generate_impl(self, chunk: List[str]) -> Any:
+        """Return one dumped completion response per prompt in ``chunk``."""
+        return [
+            self._completions_with_backoff(
+                model=self.model_name,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=self.temperature,
+                n=self.n_samples,
+            ).model_dump()
+            for prompt in chunk
+        ]
diff --git a/elleelleaime/generate/strategies/registry.py b/elleelleaime/generate/strategies/registry.py
index 0362d13f..5015154d 100644
--- a/elleelleaime/generate/strategies/registry.py
+++ b/elleelleaime/generate/strategies/registry.py
@@ -17,6 +17,9 @@ from elleelleaime.generate.strategies.models.anthropic.anthropic import ( AnthropicModels, ) +from 
elleelleaime.generate.strategies.models.mistral.mistral import ( + MistralModels, +) from typing import Tuple @@ -35,6 +38,7 @@ class PatchGenerationStrategyRegistry: "codellama-infilling": (CodeLLaMAInfilling, ("model_name",)), "codellama-instruct": (CodeLLaMAIntruct, ("model_name",)), "anthropic": (AnthropicModels, ("model_name", "max_tokens")), + "mistral": (MistralModels, ("model_name",)), } @classmethod diff --git a/poetry.lock b/poetry.lock index db172b1f..6d61231c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "accelerate" @@ -547,6 +547,20 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "eval-type-backport" +version = "0.2.0" +description = "Like `typing._eval_type`, but lets older Python versions use newer typing features." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "eval_type_backport-0.2.0-py3-none-any.whl", hash = "sha256:ac2f73d30d40c5a30a80b8739a789d6bb5e49fdffa66d7912667e2015d9c9933"}, + {file = "eval_type_backport-0.2.0.tar.gz", hash = "sha256:68796cfbc7371ebf923f03bdf7bef415f3ec098aeced24e054b253a0e78f7b37"}, +] + +[package.extras] +tests = ["pytest"] + [[package]] name = "evaluate" version = "0.4.3" @@ -1226,6 +1240,17 @@ files = [ {file = "jiter-0.7.1.tar.gz", hash = "sha256:448cf4f74f7363c34cdef26214da527e8eeffd88ba06d0b80b485ad0667baf5d"}, ] +[[package]] +name = "jsonpath-python" +version = "1.0.6" +description = "A more powerful JSONPath implementation in modern python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "jsonpath-python-1.0.6.tar.gz", hash = "sha256:dd5be4a72d8a2995c3f583cf82bf3cd1a9544cfdabf2d22595b67aff07349666"}, + {file = "jsonpath_python-1.0.6-py3-none-any.whl", hash = "sha256:1e3b78df579f5efc23565293612decee04214609208a2335884b3ee3f786b575"}, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -1296,6 +1321,28 @@ files = [ {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, ] +[[package]] +name = "mistralai" +version = "1.2.3" +description = "Python Client SDK for the Mistral AI API." 
+optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "mistralai-1.2.3-py3-none-any.whl", hash = "sha256:23902852829d1961f73cf1ecd387e8940f909f5b507c5f7fd32c7dae1a033119"}, + {file = "mistralai-1.2.3.tar.gz", hash = "sha256:096b1406f62d8262d06d3f2f826714b2da87540c9e8d829864702918149c3615"}, +] + +[package.dependencies] +eval-type-backport = ">=0.2.0,<0.3.0" +httpx = ">=0.27.0,<0.28.0" +jsonpath-python = ">=1.0.6,<2.0.0" +pydantic = ">=2.9.0,<3.0.0" +python-dateutil = "2.8.2" +typing-inspect = ">=0.9.0,<0.10.0" + +[package.extras] +gcp = ["google-auth (==2.27.0)", "requests (>=2.32.3,<3.0.0)"] + [[package]] name = "mpmath" version = "1.3.0" @@ -2327,13 +2374,13 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments [[package]] name = "python-dateutil" -version = "2.9.0.post0" +version = "2.8.2" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, - {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] [package.dependencies] @@ -3148,6 +3195,21 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +description = "Runtime inspection utilities for typing module." 
+optional = false +python-versions = "*" +files = [ + {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"}, + {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"}, +] + +[package.dependencies] +mypy-extensions = ">=0.3.0" +typing-extensions = ">=3.7.4" + [[package]] name = "tzdata" version = "2024.2" @@ -3429,4 +3491,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "092f9bfd0912968ecb67f237a04d519cca128a5f71120d020191f823e9e5c206" +content-hash = "1062ddeb7404bd9cb7ea739abc6a12bad096d497dc8163fc8189c725ec1aa72f" diff --git a/pyproject.toml b/pyproject.toml index 5eb6c686..c701aa52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ evaluate = "^0.4.2" safetensors = "^0.4.3" google-generativeai = "^0.7.2" anthropic = "^0.34.2" +mistralai = "^1.2.3" [tool.poetry.group.dev.dependencies] pytest = "^8.0.0"