RustX commited on
Commit
a9d3fa8
ยท
1 Parent(s): 1cf7a10

Create chatbot.py

Browse files
Files changed (1) hide show
  1. modules/chatbot.py +49 -0
modules/chatbot.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.chat_models import ChatOpenAI
3
+ from langchain.chains import ConversationalRetrievalChain
4
+ from langchain.prompts.prompt import PromptTemplate
5
+
6
+
7
+ class Chatbot:
8
+ _template = """๋‹ค์Œ ๋Œ€ํ™”์™€ ํ›„์† ์งˆ๋ฌธ์ด ์ฃผ์–ด์ง€๋ฉด ํ›„์† ์งˆ๋ฌธ์„ ๋…๋ฆฝํ˜• ์งˆ๋ฌธ์œผ๋กœ ๋ฐ”๊พธ์‹ญ์‹œ์˜ค.
9
+ ์งˆ๋ฌธ์ด CSV ํŒŒ์ผ์˜ ์ •๋ณด์— ๊ด€ํ•œ ๊ฒƒ์ด๋ผ๊ณ  ๊ฐ€์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
10
+ Chat History:
11
+ {chat_history}
12
+ Follow-up entry: {question}
13
+ Standalone question:"""
14
+
15
+ CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
16
+
17
+ qa_template = """"csv ํŒŒ์ผ์˜ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ๋‹ตํ•˜๋Š” AI ๋Œ€ํ™” ๋น„์„œ์ž…๋‹ˆ๋‹ค.
18
+ csv ํŒŒ์ผ์˜ ๋ฐ์ดํ„ฐ์™€ ์งˆ๋ฌธ์ด ์ œ๊ณต๋˜๋ฉฐ ์‚ฌ์šฉ์ž๊ฐ€ ํ•„์š”ํ•œ ์ •๋ณด๋ฅผ ์ฐพ๋„๋ก ๋„์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค.
19
+ ์•Œ๊ณ  ์žˆ๋Š” ์ •๋ณด์— ๋Œ€ํ•ด์„œ๋งŒ ์‘๋‹ตํ•˜์‹ญ์‹œ์˜ค. ๋‹ต์„ ์ง€์–ด๋‚ด๋ ค๊ณ  ํ•˜์ง€ ๋งˆ์„ธ์š”.
20
+ ๊ท€ํ•˜์˜ ๋‹ต๋ณ€์€ ์งง๊ณ  ์นœ๊ทผํ•˜๋ฉฐ ๋™์ผํ•œ ์–ธ์–ด๋กœ ์ž‘์„ฑ๋˜์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
21
+ question: {question}
22
+ =========
23
+ {context}
24
+ =======
25
+ """
26
+
27
+ QA_PROMPT = PromptTemplate(template=qa_template, input_variables=["question", "context"])
28
+
29
+ def __init__(self, model_name, temperature, vectors):
30
+ self.model_name = model_name
31
+ self.temperature = temperature
32
+ self.vectors = vectors
33
+
34
+ def conversational_chat(self, query):
35
+ """
36
+ Starts a conversational chat with a model via Langchain
37
+ """
38
+
39
+ chain = ConversationalRetrievalChain.from_llm(
40
+ llm=ChatOpenAI(model_name=self.model_name, temperature=self.temperature),
41
+ condense_question_prompt=self.CONDENSE_QUESTION_PROMPT,
42
+ qa_prompt=self.QA_PROMPT,
43
+ retriever=self.vectors.as_retriever(),
44
+ )
45
+ result = chain({"question": query, "chat_history": st.session_state["history"]})
46
+
47
+ st.session_state["history"].append((query, result["answer"]))
48
+
49
+ return result["answer"]