Skip to content

Commit

Permalink
feat: mistral
Browse files Browse the repository at this point in the history
  • Loading branch information
andre15silva committed Nov 19, 2024
1 parent 40f4d89 commit 3bb70a1
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 5 deletions.
Empty file.
45 changes: 45 additions & 0 deletions elleelleaime/generate/strategies/models/mistral/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy

from dotenv import load_dotenv
from typing import Any, List

import os
import mistralai
import backoff


class MistralModels(PatchGenerationStrategy):
def __init__(self, model_name: str, **kwargs) -> None:
self.model_name = model_name
self.temperature = kwargs.get("temperature", 0.0)
self.n_samples = kwargs.get("n_samples", 1)

load_dotenv()
self.client = mistralai.Mistral(os.getenv("MISTRAL_API_KEY", None))

@backoff.on_exception(
backoff.expo,
(
mistralai.models.SDKError,
mistralai.models.HTTPValidationError,
AssertionError,
),
)
def _completions_with_backoff(self, **kwargs):
response = self.client.chat.complete(**kwargs)
assert response is not None
return response

def _generate_impl(self, chunk: List[str]) -> Any:
result = []

for prompt in chunk:
completion = self._completions_with_backoff(
model=self.model_name,
messages=[{"role": "user", "content": prompt}],
temperature=self.temperature,
n=self.n_samples,
)
result.append(completion.model_dump())

return result
4 changes: 4 additions & 0 deletions elleelleaime/generate/strategies/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from elleelleaime.generate.strategies.models.anthropic.anthropic import (
AnthropicModels,
)
from elleelleaime.generate.strategies.models.mistral.mistral import (
MistralModels,
)

from typing import Tuple

Expand All @@ -35,6 +38,7 @@ class PatchGenerationStrategyRegistry:
"codellama-infilling": (CodeLLaMAInfilling, ("model_name",)),
"codellama-instruct": (CodeLLaMAIntruct, ("model_name",)),
"anthropic": (AnthropicModels, ("model_name", "max_tokens")),
"mistral": (MistralModels, ("model_name",)),
}

@classmethod
Expand Down
72 changes: 67 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ evaluate = "^0.4.2"
safetensors = "^0.4.3"
google-generativeai = "^0.7.2"
anthropic = "^0.34.2"
mistralai = "^1.2.3"

[tool.poetry.group.dev.dependencies]
pytest = "^8.0.0"
Expand Down

0 comments on commit 3bb70a1

Please sign in to comment.