[Bug fix] Fix history sequence in prompt (#1254)
deshraj authored Feb 12, 2024
1 parent d38120c commit 2f285ea
Showing 5 changed files with 22 additions and 19 deletions.
3 changes: 2 additions & 1 deletion embedchain/config/llm/base.py
@@ -24,7 +24,8 @@
 $context
-History: $history
+History:
+$history
 Query: $query
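Note: a minimal sketch of how the reworked template renders, assuming Python's string.Template (which the .substitute() calls in embedchain/llm/base.py below imply); the template text here is illustrative, not the full embedchain prompt.

from string import Template

# Illustrative template using the new layout: "History:" on its own line,
# with the $history placeholder on the line below it.
prompt_template = Template(
    "Context:\n$context\n\n"
    "History:\n$history\n\n"
    "Query: $query\n"
)

history_turns = ["Q: What is embedchain?", "A: A retrieval framework."]

rendered = prompt_template.substitute(
    context="<retrieved documents>",
    query="How do I add data?",
    history="\n".join(history_turns) or "No history",
)
print(rendered)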
18 changes: 11 additions & 7 deletions embedchain/llm/base.py
@@ -5,9 +5,7 @@
 from langchain.schema import BaseMessage as LCBaseMessage

 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -65,6 +63,14 @@ def add_history(
         self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
         self.update_history(app_id=app_id, session_id=session_id)

+    def _format_history(self) -> str:
+        """Format history to be used in prompt
+
+        :return: Formatted history
+        :rtype: str
+        """
+        return "\n".join(self.history)
+
     def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
         """
         Generates a prompt based on the given query and context, ready to be
@@ -84,10 +90,8 @@ def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:

         prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
         if prompt_contains_history:
-            # Prompt contains history
-            # If there is no history yet, we insert `- no history -`
             prompt = self.config.prompt.substitute(
-                context=context_string, query=input_query, history=self.history or "- no history -"
+                context=context_string, query=input_query, history=self._format_history() or "No history"
             )
         elif self.history and not prompt_contains_history:
             # History is present, but not included in the prompt.
@@ -98,7 +102,7 @@ def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
             ):
                 # swap in the template with history
                 prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
-                    context=context_string, query=input_query, history=self.history
+                    context=context_string, query=input_query, history=self._format_history()
                 )
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
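Note: a minimal sketch of the behaviour the new _format_history() helper changes, assuming self.history is a list of strings (as the "\n".join(self.history) above implies). Substituting the list directly puts its Python repr into the prompt; joining with newlines yields one turn per line. The template here is illustrative.

from string import Template

template = Template("History: $history")  # illustrative, not the real embedchain template
history = ["user: hi", "assistant: hello"]

# Old behaviour: the raw list is stringified into the prompt.
print(template.substitute(history=history))
# History: ['user: hi', 'assistant: hello']

# New behaviour: newline-joined turns, as _format_history() returns.
print(template.substitute(history="\n".join(history)))
# History: user: hi
# assistant: hello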
14 changes: 6 additions & 8 deletions embedchain/llm/openai.py
@@ -35,26 +35,24 @@ def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            llm = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key)
+            llm = ChatOpenAI(**kwargs, api_key=api_key)

         if self.functions is not None:
-            from langchain.chains.openai_functions import \
-                create_openai_fn_runnable
+            from langchain.chains.openai_functions import create_openai_fn_runnable
             from langchain.prompts import ChatPromptTemplate

             structured_prompt = ChatPromptTemplate.from_messages(messages)
-            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=chat)
+            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=llm)
             fn_res = runnable.invoke(
                 {
                     "input": prompt,
                 }
             )
             messages.append(AIMessage(content=json.dumps(fn_res)))

-        return chat(messages).content
+        return llm(messages).content
4 changes: 2 additions & 2 deletions embedchain/memory/base.py
@@ -92,12 +92,12 @@ def get(
         """

         if fetch_all:
-            additional_query = "ORDER BY created_at DESC"
+            additional_query = "ORDER BY created_at ASC"
             params = (app_id,)
         else:
             additional_query = """
                 AND session_id=?
-                ORDER BY created_at DESC
+                ORDER BY created_at ASC
                 LIMIT ?
             """
             params = (app_id, session_id, num_rounds)
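Note: a minimal, self-contained sqlite3 sketch of why the ORDER BY change matters; the table and column names below are hypothetical stand-ins for embedchain's chat history schema. Ascending order returns rows oldest-first, so the history fed into the prompt reads in conversational order rather than newest-first.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE chat_history (question TEXT, created_at INTEGER)")  # hypothetical schema
conn.executemany(
    "INSERT INTO chat_history VALUES (?, ?)",
    [("first question", 1), ("second question", 2), ("third question", 3)],
)

# DESC (old): newest first, so the joined history read backwards.
print([row[0] for row in conn.execute("SELECT question FROM chat_history ORDER BY created_at DESC")])
# ['third question', 'second question', 'first question']

# ASC (new): oldest first, matching the order of the conversation.
print([row[0] for row in conn.execute("SELECT question FROM chat_history ORDER BY created_at ASC")])
# ['first question', 'second question', 'third question']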
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.76"
+version = "0.1.77"
 description = "Simplest open source retrieval(RAG) framework"
 authors = [
     "Taranjeet Singh <[email protected]>",
