# MDBRetrieve / app.py
# Author: enochsjoseph
# Commit d6df619: "Initial commit"
import os
import requests
from dotenv import load_dotenv
import gradio as gr
import random
from text_generation import Client # Assumed custom package
# --- Configuration -----------------------------------------------------------
# Populate os.environ from a .env file (python-dotenv), then read the API key.
# A missing HF_API_KEY raises KeyError at import time.
load_dotenv()
hf_api_key = os.environ['HF_API_KEY']

# --- Inference client --------------------------------------------------------
# Endpoint URL and credentials both come from the environment; timeout=120 is
# forwarded to the text_generation client (seconds, per that client's docs).
# NOTE(review): this sends a "Basic" Authorization scheme; Hugging Face
# endpoints commonly expect "Bearer" — confirm what this endpoint requires.
client = Client(
    os.environ['HF_API_FALCOM_BASE'],
    headers={"Authorization": "Basic " + hf_api_key},
    timeout=120,
)
# Text generation function
def generate(input_text, max_tokens):
return client.generate(input_text, max_new_tokens=max_tokens).generated_text
# Gradio UI for single-prompt completion: a prompt box and a token-budget
# slider (1..1024, default 20) feed `generate`, whose return value is shown
# in the "Generated Text" box.
demo_text_gen = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(label="Max new tokens", minimum=1, maximum=1024, value=20),
    ],
    outputs=gr.Textbox(label="Generated Text"),
)
# Chat history management
def format_chat_prompt(message, chat_history, instruction=None):
    """Render a conversation as a plain-text prompt ending in an open assistant turn.

    Args:
        message: The new user message to be answered.
        chat_history: Iterable of 2-element ``(user_msg, bot_msg)`` sequences
            for the earlier turns; ``None`` is treated as no history.
        instruction: Optional system instruction. When non-empty it is emitted
            as a leading ``System:<instruction>`` line. When ``None`` or empty
            the result is identical to the original two-argument form.

    Returns:
        The prompt string. Each turn is rendered as ``User: ...`` /
        ``Assistant: ...`` lines, and the prompt ends with an empty
        ``Assistant:`` line for the model to continue.
    """
    # respond() passes a system instruction as the third argument, so accept it
    # here (optional, to keep two-argument callers working).
    parts = [f"System:{instruction}"] if instruction else []
    for user_msg, bot_msg in chat_history or ():
        parts.append(f"\nUser: {user_msg}\nAssistant: {bot_msg}")
    parts.append(f"\nUser: {message}\nAssistant:")
    # join avoids repeated string concatenation for long histories.
    return "".join(parts)
# Chatbot response generation
def respond(message, chat_history, instruction, temperature=0.7):
    """Generate the assistant's reply to *message* and record the exchange.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user_msg, bot_msg)`` turns (``None`` means
            none). The caller's list is not mutated; a copy is extended.
        instruction: System instruction. When non-empty it is prepended to the
            prompt as a ``System:<instruction>`` line.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        ``(reply, history)``: the generated reply text and a new history list
        that ends with ``(message, reply)``.
    """
    # Copy so a shared/caller-owned list is never modified in place.
    history = list(chat_history or [])
    # Prefix the system instruction, then the rendered conversation.
    system_line = f"System:{instruction}" if instruction else ""
    prompt = system_line + format_chat_prompt(message, history)
    response = client.generate(
        prompt,
        max_new_tokens=1024,
        # Stop once the model starts writing the next user turn, or at
        # "<|endoftext|>" (Falcon's end-of-text token; the endpoint variable
        # name suggests a Falcon model — confirm). The previous empty-string
        # stop sequence is a suffix of every output and could end generation
        # after the first token.
        stop_sequences=["\nUser:", "<|endoftext|>"],
        temperature=temperature,
    )
    reply = response.generated_text
    history.append((message, reply))
    return reply, history
# Gradio interface for chatbot
def _send_chat_message(message, history, instruction, temperature):
    """Gradio event handler: answer *message* and clear the input textbox.

    Returns a ``(textbox_value, chatbot_value)`` pair matching the event's
    ``outputs=[msg, chatbot]``.
    """
    if not message or not message.strip():
        # Nothing to send; leave both components as they were.
        return message, history
    _reply, history = respond(message, history or [], instruction, temperature)
    # Empty string clears the message box; the Chatbot shows the updated history.
    return "", history


with gr.Blocks() as demo_chatbot:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    system_msg = gr.Textbox(label="System Instruction", value="A conversation with an AI.")
    temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1, value=0.7)
    submit_btn = gr.Button("Send")
    # Event inputs/outputs must be Gradio components: the Chatbot itself carries
    # the conversation history in and out. The previous module-level plain list
    # was not a component, and it would have been shared by every session.
    chat_inputs = [msg, chatbot, system_msg, temperature_slider]
    # Two outputs to receive the handler's two return values.
    chat_outputs = [msg, chatbot]
    submit_btn.click(_send_chat_message, inputs=chat_inputs, outputs=chat_outputs)
    msg.submit(_send_chat_message, inputs=chat_inputs, outputs=chat_outputs)
# Launch Gradio apps
if __name__ == "__main__":
    # Close any Gradio apps already running in this process before relaunching.
    gr.close_all()
    # launch() blocks the main thread by default, so without
    # prevent_thread_lock=True the second app below would never start.
    demo_text_gen.launch(
        server_port=int(os.environ.get('PORT1', 7860)),
        prevent_thread_lock=True,
    )
    # This launch blocks, keeping the process (and so both servers) alive.
    demo_chatbot.launch(server_port=int(os.environ.get('PORT2', 7861)))