"""
Smart Confidant - A Magic: The Gathering chatbot.
Runs either local transformers pipelines or HuggingFace Inference API models,
and applies a custom dark Gradio theme.
"""
import gradio as gr
from gradio.themes.base import Base
from huggingface_hub import InferenceClient
import os
import base64
from pathlib import Path
import traceback
from datetime import datetime
from threading import Lock
import time
from prometheus_client import start_http_server, Counter, Summary, Gauge
# Load environment variables from .env file
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
# If python-dotenv not installed, skip (will use system env vars only)
pass
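# Illustrative .env contents this loader would pick up (the token value is a
# placeholder, not a real credential):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx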
# ============================================================================
# Configuration
# ============================================================================
LOCAL_MODELS = ["arnir0/Tiny-LLM"]
API_MODELS = ["meta-llama/Llama-3.2-3B-Instruct"]
DEFAULT_SYSTEM_MESSAGE = "You are an expert assistant for Magic: The Gathering. Your name is Smart Confidant, but people tend to call you Bob."
TITLE = "🎓🧙🏻‍♂️ Smart Confidant 🧙🏻‍♂️🎓"
# Create labeled model options for the radio selector
MODEL_OPTIONS = []
for model in LOCAL_MODELS:
MODEL_OPTIONS.append(f"{model} (local)")
for model in API_MODELS:
MODEL_OPTIONS.append(f"{model} (api)")
# Global state for local model pipeline (cached across requests)
pipe = None
stop_inference = False  # currently unused; nothing in this app toggles it
# Debug logging setup with thread-safe access
debug_logs = []
debug_lock = Lock()
MAX_LOG_LINES = 100
# ============================================================================
# Debug Logging Functions
# ============================================================================
def log_debug(message, level="INFO"):
"""Add timestamped message to debug log (thread-safe, rotating buffer)."""
timestamp = datetime.now().strftime("%H:%M:%S")
log_entry = f"[{timestamp}] [{level}] {message}"
    with debug_lock:
        debug_logs.append(log_entry)
        if len(debug_logs) > MAX_LOG_LINES:
            debug_logs.pop(0)
        # Join while still holding the lock so a concurrent writer can't
        # mutate the list mid-join.
        joined = "\n".join(debug_logs)
    print(log_entry)
    return joined
def get_debug_logs():
"""Retrieve all debug logs as a single string."""
with debug_lock:
return "\n".join(debug_logs)
# ============================================================================
# Prometheus Metrics
# ============================================================================
# Core request metrics
REQUEST_COUNTER = Counter('smart_confidant_requests_total', 'Total number of chat requests')
SUCCESSFUL_REQUESTS = Counter('smart_confidant_successful_requests_total', 'Total number of successful requests')
FAILED_REQUESTS = Counter('smart_confidant_failed_requests_total', 'Total number of failed requests')
REQUEST_DURATION = Summary('smart_confidant_request_duration_seconds', 'Time spent processing request')
# Enhanced chatbot metrics
MODEL_SELECTION_COUNTER = Counter('smart_confidant_model_selections_total',
'Count of model selections',
['model_name', 'model_type'])
TOKEN_COUNT = Summary('smart_confidant_tokens_generated', 'Number of tokens generated per response')
CONVERSATION_LENGTH = Gauge('smart_confidant_conversation_length', 'Number of messages in current conversation')
ERROR_BY_TYPE = Counter('smart_confidant_errors_by_type_total',
'Count of errors by type',
['error_type'])
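# Once start_http_server(8000) is running (see the entry point at the bottom of
# this file), these metrics are scrapeable over HTTP; the sample values below
# are illustrative:
#   $ curl -s http://localhost:8000/metrics | grep smart_confidant
#   smart_confidant_requests_total 42.0
#   smart_confidant_request_duration_seconds_count 42.0
#   smart_confidant_model_selections_total{model_name="meta-llama/Llama-3.2-3B-Instruct",model_type="api"} 40.0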
# ============================================================================
# Asset Loading & Theme Configuration
# ============================================================================
# Load background image as base64 data URL for CSS injection
ASSETS_DIR = Path(__file__).parent / "assets"
BACKGROUND_IMAGE_PATH = ASSETS_DIR / "confidant_pattern.png"
try:
with open(BACKGROUND_IMAGE_PATH, "rb") as _img_f:
_encoded_img = base64.b64encode(_img_f.read()).decode("ascii")
BACKGROUND_DATA_URL = f"data:image/png;base64,{_encoded_img}"
log_debug("Background image loaded successfully")
except Exception as e:
log_debug(f"Error loading background image: {e}", "ERROR")
BACKGROUND_DATA_URL = ""
class TransparentTheme(Base):
"""Custom Gradio theme with transparent body background to show tiled image."""
def __init__(self):
super().__init__()
super().set(
body_background_fill="*neutral_950",
body_background_fill_dark="*neutral_950",
)
# Custom CSS for dark theme with tiled background image
# Uses aggressive selectors to override Gradio's default styling
fancy_css = f"""
/* Tiled background image on page body */
body {{
background-image: url('{BACKGROUND_DATA_URL}') !important;
background-repeat: repeat !important;
background-size: auto !important;
background-attachment: fixed !important;
background-color: #1a1a1a !important;
}}
/* Make Gradio wrapper divs transparent to show background */
gradio-app,
.gradio-container,
.gradio-container > div,
.gradio-container > div > div,
.main,
.contain,
[class*="svelte"] > div,
div[class*="wrap"]:not(.gr-button):not([class*="input"]):not([class*="textbox"]):not([class*="bubble"]):not([class*="message"]),
div[class*="container"]:not([class*="input"]):not([class*="button"]) {{
background: transparent !important;
background-color: transparent !important;
background-image: none !important;
}}
/* Center and constrain main container */
.gradio-container {{
max-width: 700px !important;
margin: 0 auto !important;
padding: 20px !important;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1) !important;
border-radius: 10px !important;
font-family: 'Arial', sans-serif !important;
}}
/* Green title banner */
#title {{
text-align: center !important;
font-size: 2em !important;
margin-bottom: 20px !important;
color: #ffffff !important;
background-color: #4CAF50 !important;
padding: 20px !important;
border-radius: 10px !important;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.3) !important;
}}
/* Dark grey backgrounds for chatbot and settings components.
   Note: the svelte-* class hash below is tied to a specific Gradio build and
   may change across Gradio versions. */
.block.svelte-12cmxck {{
background-color: rgba(60, 60, 60, 0.95) !important;
border-radius: 10px !important;
}}
div[class*="bubble-wrap"],
div[class*="message-wrap"] {{
background-color: rgba(60, 60, 60, 0.95) !important;
border-radius: 10px !important;
padding: 15px !important;
}}
.label-wrap,
div[class*="accordion"] {{
background-color: rgba(60, 60, 60, 0.95) !important;
border-radius: 10px !important;
}}
/* White text for readability on dark backgrounds */
.block.svelte-12cmxck,
.block.svelte-12cmxck *,
div[class*="bubble-wrap"] *,
div[class*="message-wrap"] *,
.label-wrap,
.label-wrap * {{
color: #ffffff !important;
}}
/* Green buttons with hover effect */
.gr-button,
button {{
background-color: #4CAF50 !important;
background-image: none !important;
color: white !important;
border: none !important;
border-radius: 5px !important;
padding: 10px 20px !important;
cursor: pointer !important;
transition: background-color 0.3s ease !important;
}}
.gr-button:hover,
button:hover {{
background-color: #45a049 !important;
}}
.gr-slider input {{
color: #4CAF50 !important;
}}
"""
# ============================================================================
# Chat Response Handler
# ============================================================================
def respond(
message,
history: list[dict[str, str]],
system_message,
max_tokens,
temperature,
top_p,
selected_model: str,
):
"""
Handle chat responses using either local transformers models or HuggingFace API.
Args:
message: User's input message
history: List of previous messages in conversation
system_message: System prompt to guide model behavior
max_tokens: Maximum tokens to generate
temperature: Sampling temperature (higher = more random)
top_p: Nucleus sampling threshold
selected_model: Model identifier with "(local)" or "(api)" suffix
Yields:
str: Generated response text or error message
"""
global pipe
# Prometheus metrics: Track request start
REQUEST_COUNTER.inc()
start_time = time.perf_counter()
try:
log_debug(f"New message received: '{message[:50]}...'")
log_debug(f"Selected model: {selected_model}")
log_debug(f"Parameters - max_tokens: {max_tokens}, temp: {temperature}, top_p: {top_p}")
# Build complete message history with system prompt
messages = [{"role": "system", "content": system_message}]
messages.extend(history)
messages.append({"role": "user", "content": message})
log_debug(f"Message history length: {len(messages)}")
# Parse model type and name from selection
is_local = selected_model.endswith("(local)")
model_name = selected_model.replace(" (local)", "").replace(" (api)", "")
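        # e.g. "meta-llama/Llama-3.2-3B-Instruct (api)" ->
        #      model_name="meta-llama/Llama-3.2-3B-Instruct", is_local=False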
# Prometheus metrics: Track model selection and conversation length
model_type = "local" if is_local else "api"
MODEL_SELECTION_COUNTER.labels(model_name=model_name, model_type=model_type).inc()
CONVERSATION_LENGTH.set(len(messages))
response = ""
if is_local:
# ===== LOCAL MODEL PATH =====
log_debug(f"Using LOCAL mode with model: {model_name}")
try:
from transformers import pipeline
                import torch  # not used directly; imported so a missing torch surfaces as an ImportError here
log_debug("Transformers imported successfully")
# Load or reuse cached pipeline
if pipe is None or pipe.model.name_or_path != model_name:
log_debug(f"Loading model pipeline for: {model_name}")
pipe = pipeline("text-generation", model=model_name)
log_debug("Model pipeline loaded successfully")
else:
log_debug("Using cached model pipeline")
# Format conversation as plain text prompt
prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
log_debug(f"Prompt length: {len(prompt)} characters")
# Run inference
log_debug("Starting inference...")
outputs = pipe(
prompt,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
)
log_debug("Inference completed")
# Extract new tokens only (strip original prompt)
response = outputs[0]["generated_text"][len(prompt):]
log_debug(f"Response length: {len(response)} characters")
# Prometheus metrics: Track success and approximate token count
SUCCESSFUL_REQUESTS.inc()
TOKEN_COUNT.observe(len(response.split())) # Approximate token count using word count
yield response.strip()
except ImportError as e:
# Prometheus metrics: Track error
FAILED_REQUESTS.inc()
ERROR_BY_TYPE.labels(error_type="import_error").inc()
error_msg = f"Import error: {str(e)}"
log_debug(error_msg, "ERROR")
log_debug(traceback.format_exc(), "ERROR")
yield f"❌ Import Error: {str(e)}\n\nPlease check log.txt for details."
except Exception as e:
# Prometheus metrics: Track error
FAILED_REQUESTS.inc()
ERROR_BY_TYPE.labels(error_type="local_model_error").inc()
error_msg = f"Local model error: {str(e)}"
log_debug(error_msg, "ERROR")
log_debug(traceback.format_exc(), "ERROR")
yield f"❌ Local Model Error: {str(e)}\n\nPlease check log.txt for details."
else:
# ===== API MODEL PATH =====
log_debug(f"Using API mode with model: {model_name}")
try:
# Check for HuggingFace API token
hf_token = os.environ.get("HF_TOKEN", None)
if hf_token:
log_debug("HF_TOKEN found in environment")
else:
log_debug("No HF_TOKEN in environment - API call will likely fail", "WARN")
# Create HuggingFace Inference client
log_debug("Creating InferenceClient...")
client = InferenceClient(
api_key=hf_token,
)
log_debug("InferenceClient created successfully")
# Call chat completion API
log_debug("Starting chat completion...")
completion = client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
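                # Note: this call returns the whole completion at once; huggingface_hub
                # also accepts stream=True here to yield incremental chunks, which would
                # pair naturally with this generator-based handler.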
response = completion.choices[0].message.content
log_debug(f"Completion received. Response length: {len(response)} characters")
# Prometheus metrics: Track success and approximate token count
SUCCESSFUL_REQUESTS.inc()
TOKEN_COUNT.observe(len(response.split())) # Approximate token count using word count
yield response
except Exception as e:
# Prometheus metrics: Track error
FAILED_REQUESTS.inc()
ERROR_BY_TYPE.labels(error_type="api_error").inc()
error_msg = f"API error: {str(e)}"
log_debug(error_msg, "ERROR")
log_debug(traceback.format_exc(), "ERROR")
yield f"❌ API Error: {str(e)}\n\nPlease check log.txt for details."
except Exception as e:
# Prometheus metrics: Track error
FAILED_REQUESTS.inc()
ERROR_BY_TYPE.labels(error_type="unexpected_error").inc()
error_msg = f"Unexpected error in respond function: {str(e)}"
log_debug(error_msg, "ERROR")
log_debug(traceback.format_exc(), "ERROR")
yield f"❌ Unexpected Error: {str(e)}\n\nPlease check log.txt for details."
finally:
# Prometheus metrics: Record request duration
REQUEST_DURATION.observe(time.perf_counter() - start_time)
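
# Sketch (not wired into the UI): because respond() is a generator, it can be
# exercised directly, e.g. from a REPL; the question is a hypothetical example:
#   for chunk in respond("What does Dark Confidant do?", [], DEFAULT_SYSTEM_MESSAGE,
#                        max_tokens=128, temperature=0.7, top_p=0.95,
#                        selected_model=MODEL_OPTIONS[1]):
#       print(chunk)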
# ============================================================================
# Gradio UI Definition
# ============================================================================
# Allow Gradio to serve static files from assets directory (requires absolute path)
ASSETS_DIR_ABSOLUTE = str(ASSETS_DIR.resolve())  # .resolve() guarantees the path is absolute
gr.set_static_paths(paths=[ASSETS_DIR_ABSOLUTE])
with gr.Blocks(theme=TransparentTheme(), css=fancy_css) as demo:
# Title banner
gr.Markdown(f"<h1 id='title' style='text-align: center;'>{TITLE}</h1>")
# Chatbot component with custom avatar icons (using forward slashes for web serving)
# Gradio serves files via HTTP URLs which require forward slashes, not Windows backslashes
    MONSTER_ICON = (ASSETS_DIR / "monster_icon.png").as_posix()
    BOT_ICON = (ASSETS_DIR / "smart_confidant_icon.png").as_posix()
log_debug(f"Monster icon path: {MONSTER_ICON}")
log_debug(f"Bot icon path: {BOT_ICON}")
chatbot = gr.Chatbot(
type="messages",
avatar_images=(MONSTER_ICON, BOT_ICON)
)
# Collapsible settings panel
with gr.Accordion("⚙️ Additional Settings", open=False):
system_message = gr.Textbox(value=DEFAULT_SYSTEM_MESSAGE, label="System message")
max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
selected_model = gr.Radio(choices=MODEL_OPTIONS, label="Select Model", value=MODEL_OPTIONS[1])
# Wire up chat interface with response handler
gr.ChatInterface(
fn=respond,
chatbot=chatbot,
additional_inputs=[
system_message,
max_tokens,
temperature,
top_p,
selected_model,
],
type="messages",
)
# ============================================================================
# Application Entry Point
# ============================================================================
if __name__ == "__main__":
log_debug("="*50)
log_debug("Smart Confidant Application Starting")
log_debug(f"Available models: {MODEL_OPTIONS}")
log_debug(f"HF_TOKEN present: {'Yes' if os.environ.get('HF_TOKEN') else 'No'}")
log_debug("="*50)
# Start Prometheus metrics server on port 8000
log_debug("Starting Prometheus metrics server on port 8000")
start_http_server(8000)
log_debug("Prometheus metrics server started - available at http://0.0.0.0:8000/metrics")
# Launch on all interfaces for VM/container deployment, with Gradio share link
demo.launch(server_name="0.0.0.0", server_port=8012, share=True, allowed_paths=[ASSETS_DIR_ABSOLUTE])
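
# A minimal Prometheus scrape config for the metrics endpoint above might look
# like this (job name and target are hypothetical deployment values):
#   scrape_configs:
#     - job_name: "smart_confidant"
#       static_configs:
#         - targets: ["localhost:8000"]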