from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional
import requests
import torch
from transformers import AutoTokenizer, BertForSequenceClassification
from huggingface_hub import hf_hub_download
import os
import gc
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# -----------------------------
# CONFIG
# -----------------------------
HF_MODEL_REPO = "gaidasalsaa/indobertweet-xstress-model"
BASE_MODEL = "indolem/indobertweet-base-uncased"
PT_FILE = "best_indobertweet.pth"

# Never hardcode credentials: the token must be supplied via the
# TWITTER_BEARER_TOKEN environment variable.
BEARER_TOKEN = os.getenv("TWITTER_BEARER_TOKEN", "")
# -----------------------------
# FASTAPI APP
# -----------------------------
app = FastAPI(
    title="Stress Detection API",
    description="Detect stress levels from X (Twitter) user posts using IndoBERTweet",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables
model = None
tokenizer = None
device = None
model_loaded = False

# -----------------------------
# MODELS
# -----------------------------
class StressResponse(BaseModel):
    message: str
    data: Optional[dict] = None

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    device: Optional[str] = None
# -----------------------------
# MODEL LOADING
# -----------------------------
def load_model_once():
    """Load model only once at startup"""
    global model, tokenizer, device, model_loaded

    if model_loaded:
        logger.info("Model already loaded, skipping...")
        return

    try:
        logger.info("🚀 Starting model loading...")

        # Set device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"📱 Using device: {device}")

        # Load tokenizer
        logger.info("📝 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        logger.info("✅ Tokenizer loaded")

        # Download fine-tuned weights from the Hugging Face Hub
        logger.info(f"⬇️ Downloading {PT_FILE}...")
        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=PT_FILE
        )
        logger.info(f"✅ Model file downloaded: {model_path}")

        # Load base model architecture
        logger.info("🧠 Loading base model architecture...")
        model = BertForSequenceClassification.from_pretrained(
            BASE_MODEL,
            num_labels=2,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        )
        logger.info("✅ Base model loaded")

        # Load fine-tuned weights; strict=False tolerates key mismatches
        # between the checkpoint and the base architecture
        logger.info("🔧 Loading fine-tuned weights...")
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict, strict=False)
        logger.info("✅ Weights loaded")

        # Move to device and set eval mode
        model.to(device)
        model.eval()
        logger.info(f"✅ Model moved to {device} and set to eval mode")

        # Clear memory
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        model_loaded = True
        logger.info("✅ Model loading complete!")
    except Exception as e:
        logger.error(f"❌ Failed to load model: {str(e)}")
        raise
# -----------------------------
# HELPER FUNCTIONS
# -----------------------------
def get_user_id(username: str):
    """Get X (Twitter) user ID from username"""
    url = f"https://api.x.com/2/users/by/username/{username}"
    headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
    try:
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        return response.json()["data"]["id"], None
    except requests.exceptions.RequestException as e:
        logger.error(f"Twitter API error (get_user_id): {str(e)}")
        return None, str(e)
    except KeyError:
        return None, "User not found"

def fetch_tweets(user_id: str, limit: int = 25):
    """Fetch recent tweets from a user's timeline"""
    url = f"https://api.x.com/2/users/{user_id}/tweets"
    params = {
        "max_results": min(limit, 100),  # Twitter API max is 100
        "tweet.fields": "id,text,created_at"
    }
    headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
    try:
        response = requests.get(url, headers=headers, params=params, timeout=10)
        response.raise_for_status()
        tweets = response.json().get("data", [])
        return [tweet["text"] for tweet in tweets], None
    except requests.exceptions.RequestException as e:
        logger.error(f"Twitter API error (fetch_tweets): {str(e)}")
        return None, str(e)
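
# For reference, a successful /2/users/{id}/tweets response has roughly the
# shape sketched below (illustrative, based on the tweet.fields requested
# above; not an exhaustive schema). fetch_tweets() keeps only the "text"
# field of each entry:
#
#   {
#     "data": [
#       {"id": "...", "text": "...", "created_at": "..."},
#       ...
#     ],
#     "meta": {...}
#   }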
def predict_stress(text: str):
    """Predict stress level from text"""
    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)[0]
            label = torch.argmax(probs).item()
            # Confidence is the probability of the predicted class
            confidence = float(probs[label])

        return label, confidence
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise
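
# Illustrative example only (the exact numbers depend on the fine-tuned
# weights): for an Indonesian tweet such as "aku capek banget hari ini",
# predict_stress() might return something like (1, 0.87), i.e. the stressed
# class with ~87% confidence.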
# -----------------------------
# STARTUP EVENT
# -----------------------------
@app.on_event("startup")
async def startup_event():
    """Load model when app starts"""
    logger.info("🚀 Application starting...")
    load_model_once()
    logger.info("✅ Application ready!")
# -----------------------------
# API ENDPOINTS
# -----------------------------
@app.get("/")
async def root():
    """Root endpoint with API info"""
    return {
        "name": "Stress Detection API",
        "version": "1.0.0",
        "status": "online",
        "endpoints": {
            "health": "/health",
            "analyze": "/analyze/{username}",
            "docs": "/docs"
        }
    }

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy" if model_loaded else "loading",
        model_loaded=model_loaded,
        device=str(device) if device else None
    )
@app.get("/analyze/{username}", response_model=StressResponse)
async def analyze_user(username: str):
    """
    Analyze stress level from an X (Twitter) user's recent tweets

    - **username**: Twitter username (without @)
    """
    # Ensure model is loaded
    if not model_loaded:
        logger.warning("Model not loaded yet, loading now...")
        load_model_once()

    # Remove @ if the user included it
    username = username.lstrip("@")
    logger.info(f"🔍 Analyzing user: @{username}")

    # 1. Get user ID
    user_id, error = get_user_id(username)
    if error:
        raise HTTPException(
            status_code=404,
            detail=f"Failed to fetch user profile: {error}"
        )

    # 2. Fetch tweets
    tweets, error = fetch_tweets(user_id)
    if error:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch tweets: {error}"
        )

    if not tweets:
        return StressResponse(
            message="No tweets found. User may be protected or have no tweets.",
            data=None
        )

    # 3. Analyze each tweet
    labels = []
    confidences = []
    for i, tweet in enumerate(tweets):
        try:
            label, confidence = predict_stress(tweet)
            labels.append(label)
            confidences.append(confidence)
            logger.info(f"Tweet {i+1}/{len(tweets)}: label={label}, confidence={confidence:.2f}")
        except Exception as e:
            logger.warning(f"Skipping tweet {i+1} due to error: {str(e)}")
            continue

    if not labels:
        raise HTTPException(
            status_code=500,
            detail="Failed to analyze any tweets"
        )

    # 4. Calculate statistics: share of tweets classified as stressed and
    # mean prediction confidence, both as percentages
    stress_percentage = round(sum(labels) / len(labels) * 100, 2)
    avg_confidence = round(sum(confidences) / len(confidences) * 100, 2)

    # Map the stress percentage onto four coarse bands
    if stress_percentage <= 25:
        status = 0
        status_text = "Low Stress"
    elif stress_percentage <= 50:
        status = 1
        status_text = "Medium Stress"
    elif stress_percentage <= 75:
        status = 2
        status_text = "High Stress"
    else:
        status = 3
        status_text = "Very High Stress"

    logger.info(f"✅ Analysis complete: {stress_percentage}% stress ({status_text})")

    return StressResponse(
        message="Analysis successful",
        data={
            "username": username,
            "total_tweets": len(tweets),
            "analyzed_tweets": len(labels),
            "stress_level": stress_percentage,
            "stress_status": status,
            "stress_status_text": status_text,
            "average_confidence": avg_confidence
        }
    )
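
# A successful /analyze/{username} response therefore looks like this
# (values are illustrative):
#
#   {
#     "message": "Analysis successful",
#     "data": {
#       "username": "someuser",
#       "total_tweets": 25,
#       "analyzed_tweets": 25,
#       "stress_level": 48.0,
#       "stress_status": 1,
#       "stress_status_text": "Medium Stress",
#       "average_confidence": 81.5
#     }
#   }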
# -----------------------------
# ERROR HANDLERS
# -----------------------------
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    logger.error(f"Unhandled exception: {str(exc)}")
    # Exception handlers must return a Response, not a Pydantic model
    return JSONResponse(
        status_code=500,
        content={
            "message": f"Internal server error: {str(exc)}",
            "data": None
        }
    )
# -----------------------------
# RUN (for local testing only)
# -----------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
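
# Example usage once the server is running locally (commands assume the
# default port 7860 set above; "someuser" is a placeholder handle):
#
#   curl http://localhost:7860/health
#   curl http://localhost:7860/analyze/someuser
#
# Interactive documentation is served at http://localhost:7860/docs.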