# xstress-api / app.py
# Source: commit 6c55ffe ("add all") by gaidasalsaa
import logging
import os
from typing import Optional
from urllib.parse import quote

import requests
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from transformers import AutoTokenizer, BertForSequenceClassification
logger = logging.getLogger("app")
logging.basicConfig(level=logging.INFO)
# =====================================================
# CONFIG
# =====================================================
HF_MODEL_REPO = "gaidasalsaa/indobertweet-xstress-model"
BASE_MODEL = "indolem/indobertweet-base-uncased"
PT_FILE = "model_indobertweet.pth"
# X (Twitter) API v2 bearer token, read from the environment instead of being
# hard-coded: a token committed to source control is leaked and must be
# revoked/rotated. When unset it is the empty string, so X API calls get a
# non-200 answer (handled by the callers) rather than crashing at import.
BEARER_TOKEN = os.environ.get("BEARER_TOKEN", "")
if not BEARER_TOKEN:
    logger.warning(
        "BEARER_TOKEN environment variable is not set; X API requests will fail."
    )
# =====================================================
# GLOBAL MODEL STORAGE
# =====================================================
# Set by load_model_once(); None until the model has been loaded.
tokenizer = None
model = None
# =====================================================
# LOAD MODEL
# =====================================================
def load_model_once():
    """Load the tokenizer and fine-tuned classifier into the module globals.

    Idempotent: returns immediately when both ``tokenizer`` and ``model`` are
    already set. Otherwise it:

    1. loads the tokenizer for ``BASE_MODEL``,
    2. downloads ``PT_FILE`` from the Hub repo ``HF_MODEL_REPO``,
    3. builds a 2-label ``BertForSequenceClassification`` from ``BASE_MODEL``,
    4. loads the downloaded state dict into it with strict key matching,
    5. moves it to CPU and switches it to eval mode.

    The globals are assigned only after every step has succeeded. A failure
    part-way (download error, mismatched weights) therefore leaves them as
    ``None`` and the exception propagates; a later call retries instead of
    reporting "already loaded" with a model whose classification head was
    never fine-tuned.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        logger.info("Model already loaded.")
        return
    logger.info("Starting model loading...")
    device = "cpu"
    logger.info(f"Using device: {device}")

    # ---- load tokenizer ----
    logger.info("Loading tokenizer...")
    new_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    logger.info("Tokenizer loaded")

    # ---- download .pth ----
    logger.info(f"Downloading {PT_FILE}...")
    model_path = hf_hub_download(
        repo_id=HF_MODEL_REPO,
        filename=PT_FILE,
    )
    logger.info(f"Model file downloaded: {model_path}")

    logger.info("Loading base model architecture...")
    new_model = BertForSequenceClassification.from_pretrained(
        BASE_MODEL,
        num_labels=2,
    )

    logger.info("Loading fine-tuned weights (.pth)...")
    # weights_only=True limits unpickling to tensors and plain containers, so
    # a tampered checkpoint downloaded from the Hub cannot run arbitrary code.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    new_model.load_state_dict(state_dict, strict=True)
    logger.info("Weights loaded successfully")
    new_model.to(device)
    new_model.eval()

    # Publish both objects together, only once everything above succeeded.
    tokenizer, model = new_tokenizer, new_model
    logger.info("MODEL READY")
# =====================================================
# FASTAPI
# =====================================================
app = FastAPI(title="Stress Detection API")


@app.on_event("startup")
def startup_event():
    """FastAPI startup hook: populate the model globals before serving.

    ``predict_stress`` reads the module-level ``tokenizer`` and ``model``
    directly and never loads them itself, so they are loaded here once, when
    the application starts.
    """
    logger.info("Starting model loading on startup...")
    load_model_once()
class StressResponse(BaseModel):
    """Response body of the ``/analyze/{username}`` endpoint.

    ``data`` is None when no analysis could be produced; ``message`` then
    explains why.
    """

    # Human-readable outcome, e.g. "Analysis complete" or "Failed to fetch profile".
    message: str
    # Analysis results (username, tweet count, stress level/status, keywords),
    # or None when the profile or tweets could not be obtained.
    data: Optional[dict] = None
# =====================================================
# TWITTER API
# =====================================================
def get_user_id(username, timeout=10):
    """Resolve an X (Twitter) username to its user ID via the X API v2.

    Args:
        username: Handle to look up (without the leading "@").
        timeout: Seconds to wait for the X API before giving up.

    Returns:
        ``(user_id, payload)``. On success ``user_id`` is the ID from the
        response and ``payload`` the decoded JSON body. On failure
        ``user_id`` is None and ``payload`` is one of:

        * the decoded error body, for a non-200 answer or a 200 answer
          without a usable ``data.id``;
        * ``{"error": "..."}``, when the request itself failed or the body
          was not JSON.
    """
    # Percent-encode the handle so characters such as "/", "?" or "#" stay
    # inside the path segment instead of altering the request URL.
    url = f"https://api.x.com/2/users/by/username/{quote(username, safe='')}"
    headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
    try:
        # Without a timeout a stalled connection could block this worker indefinitely.
        r = requests.get(url, headers=headers, timeout=timeout)
        payload = r.json()
    except (requests.RequestException, ValueError) as exc:
        # Network failure, or a body that is not JSON (e.g. an HTML error page).
        logger.warning("X user lookup for %r failed: %s", username, exc)
        return None, {"error": str(exc)}
    if r.status_code != 200:
        return None, payload
    # The X API can answer 200 with only an "errors" list (e.g. unknown or
    # suspended user), so "data" must not be assumed present.
    user = payload.get("data") if isinstance(payload, dict) else None
    if not isinstance(user, dict) or "id" not in user:
        return None, payload
    return user["id"], payload
def fetch_tweets(user_id, limit=25, timeout=10):
    """Fetch the text of the most recent tweets posted by ``user_id``.

    Args:
        user_id: X user ID, e.g. as returned by ``get_user_id``.
        limit: Maximum number of tweets to return. The X API accepts
            ``max_results`` only between 5 and 100, so the request is clamped
            into that range and the result trimmed to ``limit``. Only one page
            is fetched, so at most 100 tweets are returned.
        timeout: Seconds to wait for the X API before giving up.

    Returns:
        ``(texts, payload)``. ``texts`` is a list of tweet texts (possibly
        empty) and ``payload`` the decoded JSON body. On failure ``texts`` is
        None and ``payload`` is one of:

        * the decoded error body, for a non-200 answer;
        * ``{"error": "..."}``, when the request failed or the body was not
          JSON.
    """
    url = f"https://api.x.com/2/users/{user_id}/tweets"
    # X API v2 rejects max_results outside 5..100.
    max_results = min(max(limit, 5), 100)
    params = {"max_results": max_results, "tweet.fields": "id,text,created_at"}
    headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
    try:
        r = requests.get(url, headers=headers, params=params, timeout=timeout)
        payload = r.json()
    except (requests.RequestException, ValueError) as exc:
        logger.warning("Fetching tweets for user %s failed: %s", user_id, exc)
        return None, {"error": str(exc)}
    if r.status_code != 200:
        return None, payload
    # The response may omit "data" (e.g. when the user has no tweets).
    tweets = payload.get("data", []) if isinstance(payload, dict) else []
    texts = [t["text"] for t in tweets]
    # max(limit, 0) avoids a negative slice dropping items from the end.
    return texts[:max(limit, 0)], payload
# =====================================================
# KEYWORD EXTRACTION
# =====================================================
# Indonesian (including informal) words that the keyword report looks for.
# Built once at import; the tuple order fixes the order of returned keywords.
_STRESS_WORDS = (
    "gelisah", "cemas", "tidur", "takut", "hati",
    "resah", "sampe", "tenang", "suka", "mulu",
    "sedih", "ngerasa", "gimana", "gatau",
    "perasaan", "nangis", "deg", "khawatir",
    "pikiran", "harap", "gabisa", "bener", "pengen",
    "sakit", "susah", "bangun", "biar", "jam", "kaya",
    "bingung", "mikir", "tuhan", "mikirin",
    "bawaannya", "marah", "tbtb", "anjir", "cape",
    "panik", "enak", "kali", "pusing", "semoga",
    "kadang", "langsung", "kemarin", "tugas",
    "males",
)


def extract_keywords(tweets):
    """Return the stress-related words that occur in any of ``tweets``.

    Matching is case-insensitive and by substring. For example, "hatinya"
    matches "hati", and "sekali" also matches "kali".

    Each word is reported at most once, in the fixed order of
    ``_STRESS_WORDS``. The order is deterministic, unlike iterating a ``set``
    of strings, whose order changes between interpreter runs because of hash
    randomization.

    Args:
        tweets: Iterable of tweet texts; ``None`` or empty yields ``[]``.

    Returns:
        List of matched stress words.
    """
    if not tweets:
        return []
    # Joining with a newline cannot create a match spanning two tweets,
    # because none of the stress words contains a newline.
    corpus = "\n".join(tweet.lower() for tweet in tweets)
    return [word for word in _STRESS_WORDS if word in corpus]
# =====================================================
# INFERENCE
# =====================================================
def predict_stress(text):
    """Classify one text with the loaded tokenizer/model globals.

    The text is truncated to 128 tokens. Returns ``(label, p_class1)``:

    * ``label`` is the index of the most probable class (``analyze`` counts
      label 1 towards its stress percentage);
    * ``p_class1`` is the softmax probability of class 1, as a float.
    """
    encoded = tokenizer(
        text,
        max_length=128,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    distribution = logits.softmax(dim=1)[0]
    predicted = distribution.argmax().item()
    class1_probability = float(distribution[1])
    return predicted, class1_probability
# =====================================================
# API ROUTE
# =====================================================
@app.get("/analyze/{username}", response_model=StressResponse)
def analyze(username: str):
    """Estimate how stressed an X user appears from their recent tweets.

    Each fetched tweet is classified by ``predict_stress``. The share of
    tweets labelled 1 becomes ``stress_level``, a percentage rounded to 2
    decimals. It is then bucketed into ``stress_status``:

    * 0 for <= 25
    * 1 for <= 50
    * 2 for <= 75
    * 3 above 75

    Matched stress keywords are reported alongside.

    When the profile cannot be resolved, or the tweet fetch yields nothing
    (failure or empty timeline), ``message`` says so and ``data`` is None.
    """
    profile_id = get_user_id(username)[0]
    if profile_id is None:
        return StressResponse(message="Failed to fetch profile", data=None)

    texts = fetch_tweets(profile_id)[0]
    if not texts:
        return StressResponse(message="No tweets available", data=None)

    predicted_labels = [predict_stress(text)[0] for text in texts]
    stress_percentage = round(sum(predicted_labels) / len(predicted_labels) * 100, 2)

    # Number of bucket ceilings strictly exceeded -> status 0..3.
    stress_status = sum(stress_percentage > ceiling for ceiling in (25, 50, 75))

    summary = {
        "username": username,
        "total_tweets": len(texts),
        "stress_level": stress_percentage,
        "keywords": extract_keywords(texts),
        "stress_status": stress_status,
    }
    return StressResponse(message="Analysis complete", data=summary)