code_query.py
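"""Query a persisted Chroma vector store of code via a RetrievalQA chain.

Expects a .env file providing OPENAI_API_KEY and COLLECTION_NAME, and a
Chroma index already persisted to the ./db directory (presumably built by
a separate ingestion script, not shown here).
"""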
import os
import sys

import langchain
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from rich import print as rprint

# Emit verbose chain-execution logs.
langchain.debug = True

# Load environment variables (OPENAI_API_KEY, COLLECTION_NAME) from .env.
load_dotenv()

# Embedding function; must match the one used when the index was built.
ef = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))

# Open the persisted Chroma collection (assumes an index was previously
# written to ./db).
db = Chroma(
    collection_name=os.getenv("COLLECTION_NAME"),
    persist_directory="db",
    embedding_function=ef,
)

# Deterministic chat model for answering questions about the retrieved code.
llm = ChatOpenAI(
    temperature=0.0,
    openai_api_key=os.getenv("OPENAI_API_KEY"),
)

# "stuff" packs all retrieved documents directly into a single prompt.
retriever = db.as_retriever()
chain = RetrievalQA.from_chain_type(llm, chain_type="stuff", retriever=retriever)

# Take the question from the command line, falling back to a sample query.
query = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "How to modify the max_tokens parameter in the continue codebase?"
)

results = chain({"query": query})
rprint(f"[bold cyan]Query:[/bold cyan] {results['query']}")
rprint(f"[bold cyan]Result:[/bold cyan] {results['result']}")