# xstress-api / app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional
import requests
import torch
from transformers import AutoTokenizer, BertForSequenceClassification
from huggingface_hub import hf_hub_download
import os
import gc
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# -----------------------------
# CONFIG
# -----------------------------
HF_MODEL_REPO = "gaidasalsaa/indobertweet-xstress-model"
BASE_MODEL = "indolem/indobertweet-base-uncased"
PT_FILE = "best_indobertweet.pth"
# Read the API credential from the environment; never hard-code tokens in source.
BEARER_TOKEN = os.getenv("TWITTER_BEARER_TOKEN", "")
if not BEARER_TOKEN:
    logger.warning("TWITTER_BEARER_TOKEN is not set; Twitter API requests will fail")
# -----------------------------
# FASTAPI APP
# -----------------------------
app = FastAPI(
title="Stress Detection API",
description="Detect stress levels from X(Twitter) user posts using IndoBERTweet",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
# Add CORS middleware (wide open for demo purposes; restrict allow_origins in production)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global variables
model = None
tokenizer = None
device = None
model_loaded = False
# -----------------------------
# MODELS
# -----------------------------
class StressResponse(BaseModel):
message: str
data: Optional[dict] = None
class HealthResponse(BaseModel):
status: str
model_loaded: bool
device: Optional[str] = None
# -----------------------------
# MODEL LOADING
# -----------------------------
def load_model_once():
"""Load model only once at startup"""
global model, tokenizer, device, model_loaded
if model_loaded:
logger.info("Model already loaded, skipping...")
return
try:
logger.info("πŸ”„ Starting model loading...")
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"πŸ“± Using device: {device}")
# Load tokenizer
logger.info("πŸ“ Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
logger.info("βœ… Tokenizer loaded")
# Download model weights
logger.info(f"⬇️ Downloading {PT_FILE}...")
model_path = hf_hub_download(
repo_id=HF_MODEL_REPO,
filename=PT_FILE
)
logger.info(f"βœ… Model file downloaded: {model_path}")
# Load base model
logger.info("🧠 Loading base model architecture...")
model = BertForSequenceClassification.from_pretrained(
BASE_MODEL,
num_labels=2,
low_cpu_mem_usage=True,
torch_dtype=torch.float16 if device == "cuda" else torch.float32
)
logger.info("βœ… Base model loaded")
# Load fine-tuned weights
logger.info("πŸ”§ Loading fine-tuned weights...")
        state_dict = torch.load(model_path, map_location=device)
        # strict=False tolerates missing/unexpected keys but skips them silently,
        # so the checkpoint must genuinely match this architecture.
        model.load_state_dict(state_dict, strict=False)
logger.info("βœ… Weights loaded")
# Move to device and set eval mode
model.to(device)
model.eval()
logger.info(f"βœ… Model moved to {device} and set to eval mode")
# Clear memory
gc.collect()
if device == "cuda":
torch.cuda.empty_cache()
model_loaded = True
logger.info("βœ… Model loading complete!")
except Exception as e:
logger.error(f"❌ Failed to load model: {str(e)}")
raise
# -----------------------------
# HELPER FUNCTIONS
# -----------------------------
def get_user_id(username: str):
"""Get Twitter user ID from username"""
url = f"https://api.x.com/2/users/by/username/{username}"
headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
return response.json()["data"]["id"], None
except requests.exceptions.RequestException as e:
logger.error(f"Twitter API error (get_user_id): {str(e)}")
return None, str(e)
except KeyError:
return None, "User not found"
def fetch_tweets(user_id: str, limit: int = 25):
"""Fetch recent tweets from user"""
url = f"https://api.x.com/2/users/{user_id}/tweets"
    params = {
        "max_results": max(5, min(limit, 100)),  # the endpoint accepts 5-100 per request
        "tweet.fields": "id,text,created_at"
    }
headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
try:
response = requests.get(url, headers=headers, params=params, timeout=10)
response.raise_for_status()
tweets = response.json().get("data", [])
return [tweet["text"] for tweet in tweets], None
except requests.exceptions.RequestException as e:
logger.error(f"Twitter API error (fetch_tweets): {str(e)}")
return None, str(e)
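# For reference, a successful response from the tweets endpoint has roughly this
# shape (a sketch based on the requested tweet.fields; real payloads may carry
# additional keys):
#   {
#     "data": [
#       {"id": "160...", "text": "...", "created_at": "2024-01-01T00:00:00.000Z"},
#       ...
#     ],
#     "meta": {"result_count": 25, ...}
#   }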
def predict_stress(text: str):
"""Predict stress level from text"""
try:
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
padding=True,
max_length=128
).to(device)
with torch.no_grad():
outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)[0]
        label = torch.argmax(probs).item()
        # Note: this is the probability of the stress class (index 1),
        # not the confidence of whichever label was predicted.
        confidence = float(probs[1])
        return label, confidence
except Exception as e:
logger.error(f"Prediction error: {str(e)}")
raise
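# Illustrative usage (the input text and output values are hypothetical):
#   label, p_stress = predict_stress("aku capek banget hari ini")
#   # label -> 0 (not stressed) or 1 (stressed); p_stress -> e.g. 0.87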
# -----------------------------
# STARTUP EVENT
# -----------------------------
@app.on_event("startup")
async def startup_event():
"""Load model when app starts"""
logger.info("πŸš€ Application starting...")
load_model_once()
logger.info("βœ… Application ready!")
# -----------------------------
# API ENDPOINTS
# -----------------------------
@app.get("/", response_model=dict)
async def root():
"""Root endpoint with API info"""
return {
"name": "Stress Detection API",
"version": "1.0.0",
"status": "online",
"endpoints": {
"health": "/health",
"analyze": "/analyze/{username}",
"docs": "/docs"
}
}
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint"""
return HealthResponse(
status="healthy" if model_loaded else "loading",
model_loaded=model_loaded,
device=str(device) if device else None
)
@app.get("/analyze/{username}", response_model=StressResponse)
async def analyze_user(username: str):
"""
Analyze stress level from Twitter user's recent tweets
- **username**: Twitter username (without @)
"""
# Ensure model is loaded
if not model_loaded:
logger.warning("Model not loaded yet, loading now...")
load_model_once()
# Remove @ if user included it
username = username.lstrip("@")
logger.info(f"πŸ“Š Analyzing user: @{username}")
# 1. Get user ID
user_id, error = get_user_id(username)
if error:
raise HTTPException(
status_code=404,
detail=f"Failed to fetch user profile: {error}"
)
# 2. Fetch tweets
tweets, error = fetch_tweets(user_id)
if error:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch tweets: {error}"
)
if not tweets:
return StressResponse(
message="No tweets found. User may be protected or has no tweets.",
data=None
)
# 3. Analyze each tweet
labels = []
confidences = []
for i, tweet in enumerate(tweets):
try:
label, confidence = predict_stress(tweet)
labels.append(label)
confidences.append(confidence)
logger.info(f"Tweet {i+1}/{len(tweets)}: label={label}, confidence={confidence:.2f}")
except Exception as e:
logger.warning(f"Skipping tweet {i+1} due to error: {str(e)}")
continue
if not labels:
raise HTTPException(
status_code=500,
detail="Failed to analyze any tweets"
)
    # 4. Calculate statistics
    # stress_percentage: share of analyzed tweets classified as stressed
    # avg_confidence: mean P(stress) across analyzed tweets, as a percentage
    stress_percentage = round(sum(labels) / len(labels) * 100, 2)
    avg_confidence = round(sum(confidences) / len(confidences) * 100, 2)
# Determine stress status
if stress_percentage <= 25:
status = 0
status_text = "Low Stress"
elif stress_percentage <= 50:
status = 1
status_text = "Medium Stress"
elif stress_percentage <= 75:
status = 2
status_text = "High Stress"
else:
status = 3
status_text = "Very High Stress"
logger.info(f"βœ… Analysis complete: {stress_percentage}% stress ({status_text})")
return StressResponse(
message="Analysis successful",
data={
"username": username,
"total_tweets": len(tweets),
"analyzed_tweets": len(labels),
"stress_level": stress_percentage,
"stress_status": status,
"stress_status_text": status_text,
"average_confidence": avg_confidence
}
)
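# Shape of a successful /analyze response (field names come from the return value
# above; the values are hypothetical):
#   {
#     "message": "Analysis successful",
#     "data": {
#       "username": "someuser",
#       "total_tweets": 25,
#       "analyzed_tweets": 25,
#       "stress_level": 36.0,
#       "stress_status": 1,
#       "stress_status_text": "Medium Stress",
#       "average_confidence": 41.2
#     }
#   }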
# -----------------------------
# ERROR HANDLERS
# -----------------------------
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    logger.error(f"Unhandled exception: {str(exc)}")
    # Exception handlers must return a Response object, not a Pydantic model.
    return JSONResponse(
        status_code=500,
        content={"message": f"Internal server error: {str(exc)}", "data": None}
    )
# -----------------------------
# RUN (for local testing only)
# -----------------------------
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
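# Quick smoke test once the server is running (port 7860 as configured above;
# the username is hypothetical):
#   curl http://localhost:7860/health
#   curl http://localhost:7860/analyze/someuser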