From c020e65a504e560a45bda7f60a057fe0e573f699 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Fri, 12 Jan 2024 12:33:07 +0530 Subject: [PATCH] [Improvement] update LLM memory get function (#1162) Co-authored-by: Deven Patel --- embedchain/embedchain.py | 7 +++-- embedchain/memory/base.py | 45 +++++++++++++++++++++++++------- pyproject.toml | 2 +- tests/memory/test_chat_memory.py | 4 +++ 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index fbd1a5a125..f52956a761 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -667,8 +667,11 @@ def reset(self): # Send anonymous telemetry self.telemetry.capture(event_name="reset", properties=self._telemetry_props) - def get_history(self, num_rounds: int = 10, display_format: bool = True): - return self.llm.memory.get(app_id=self.config.id, num_rounds=num_rounds, display_format=display_format) + def get_history(self, num_rounds: int = 10, display_format: bool = True, session_id: Optional[str] = "default"): + history = self.llm.memory.get( + app_id=self.config.id, session_id=session_id, num_rounds=num_rounds, display_format=display_format + ) + return history def delete_session_chat_history(self, session_id: str = "default"): self.llm.memory.delete(app_id=self.config.id, session_id=session_id) diff --git a/embedchain/memory/base.py b/embedchain/memory/base.py index 9bfa04f2de..378ad47993 100644 --- a/embedchain/memory/base.py +++ b/embedchain/memory/base.py @@ -73,21 +73,40 @@ def delete(self, app_id: str, session_id: Optional[str] = None): self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params) self.connection.commit() - def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]: + def get( + self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False + ) -> list[ChatMessage]: """ - Get the most recent num_rounds rounds of conversations - between human and AI, for a given app_id. + Get the chat history for a given app_id. + + param: app_id - The app_id to get chat history + param: session_id (optional) - The session_id to get chat history. Defaults to "default" + param: num_rounds (optional) - The number of rounds to get chat history. Defaults to 10 + param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False + param: display_format (optional) - Whether to return the chat history in display format. Defaults to False """ - QUERY = """ + base_query = """ SELECT * FROM ec_chat_history - WHERE app_id=? AND session_id=? - ORDER BY created_at DESC - LIMIT ? + WHERE app_id=? """ + + if fetch_all: + additional_query = "ORDER BY created_at DESC" + params = (app_id,) + else: + additional_query = """ + AND session_id=? + ORDER BY created_at DESC + LIMIT ? + """ + params = (app_id, session_id, num_rounds) + + QUERY = base_query + additional_query + self.cursor.execute( QUERY, - (app_id, session_id, num_rounds), + params, ) results = self.cursor.fetchall() @@ -97,7 +116,15 @@ def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[C metadata = self._deserialize_json(metadata=metadata) # Return list of dict if display_format is True if display_format: - history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp}) + history.append( + { + "session_id": session_id, + "human": question, + "ai": answer, + "metadata": metadata, + "timestamp": timestamp, + } + ) else: memory = ChatMessage() memory.add_user_message(question, metadata=metadata) diff --git a/pyproject.toml b/pyproject.toml index 28c0979a1a..f70003c744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.62" +version = "0.1.63" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" authors = [ "Taranjeet Singh ", diff --git a/tests/memory/test_chat_memory.py b/tests/memory/test_chat_memory.py index bcf6601268..6fac2a643b 100644 --- a/tests/memory/test_chat_memory.py +++ b/tests/memory/test_chat_memory.py @@ -44,6 +44,10 @@ def test_get(chat_memory_instance): assert len(recent_memories) == 5 + all_memories = chat_memory_instance.get(app_id, fetch_all=True) + + assert len(all_memories) == 6 + def test_delete_chat_history(chat_memory_instance): app_id = "test_app"