Skip to content

Commit

Permalink
Merge branch 'main' of github.com:crewAIInc/crewAI into knowledge
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzejay committed Nov 15, 2024
2 parents cb03ee6 + 3dc0231 commit cdf5233
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 89 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ dependencies = [
"pyvis>=0.3.2",
"uv>=0.4.25",
"tomli-w>=1.1.0",
"chromadb>=0.4.24",
"tomli>=2.0.2",
"chromadb>=0.5.18",
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion src/crewai/memory/entity/entity_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, crew=None, embedder_config=None, storage=None):
if storage
else RAGStorage(
type="entities",
allow_reset=False,
allow_reset=True,
embedder_config=embedder_config,
crew=crew,
)
Expand Down
88 changes: 53 additions & 35 deletions src/crewai/memory/storage/rag_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
self._initialize_app()

def _set_embedder_config(self):
import chromadb.utils.embedding_functions as embedding_functions

if self.embedder_config is None:
self.embedder_config = self._create_default_embedding_function()

Expand All @@ -61,58 +59,76 @@ def _set_embedder_config(self):
config = self.embedder_config.get("config", {})
model_name = config.get("model")
if provider == "openai":
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)

self.embedder_config = OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
)
elif provider == "azure":
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)

self.embedder_config = OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
)
elif provider == "ollama":
from openai import OpenAI

class OllamaEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
client = OpenAI(
base_url="http://localhost:11434/v1",
api_key=config.get("api_key", "ollama"),
)
try:
response = client.embeddings.create(
input=input, model=model_name
)
embeddings = [item.embedding for item in response.data]
return cast(Embeddings, embeddings)
except Exception as e:
raise e
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)

self.embedder_config = OllamaEmbeddingFunction()
self.embedder_config = OllamaEmbeddingFunction(
url=config.get("url", "http://localhost:11434/api/embeddings"),
model_name=model_name,
)
elif provider == "vertexai":
self.embedder_config = (
embedding_functions.GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)

self.embedder_config = GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "google":
self.embedder_config = (
embedding_functions.GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)

self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "cohere":
self.embedder_config = embedding_functions.CohereEmbeddingFunction(
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)

self.embedder_config = CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "bedrock":
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)

self.embedder_config = AmazonBedrockEmbeddingFunction(
session=config.get("session"),
)
elif provider == "huggingface":
self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)

self.embedder_config = HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
elif provider == "watson":
Expand Down Expand Up @@ -253,8 +269,10 @@ def reset(self) -> None:
)

def _create_default_embedding_function(self):
import chromadb.utils.embedding_functions as embedding_functions
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)

return embedding_functions.OpenAIEmbeddingFunction(
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
Loading

0 comments on commit cdf5233

Please sign in to comment.