gaidasalsaa committed on
Commit
e73b762
·
1 Parent(s): b38b1f8

Added Dockerfile

Browse files
Files changed (3) hide show
  1. Dockerfile +34 -0
  2. README.md +14 -2
  3. app.py +214 -132
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use the official slim Python runtime as the base image
FROM python:3.10-slim

# Hugging Face Docker Spaces run the container as UID 1000. Create that user
# with a real home directory so libraries that cache under $HOME (for example
# the Hugging Face Hub download cache used by app.py) have a writable location.
RUN useradd --create-home --uid 1000 user

# Set working directory
WORKDIR /app

# Install system dependencies (curl is used by the HEALTHCHECK below);
# --no-install-recommends keeps optional packages out of the image
RUN apt-get update && apt-get install -y --no-install-recommends \
        build-essential \
        curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first (for better layer caching)
COPY requirements.txt .

# Install Python dependencies system-wide while still running as root
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy application code, owned by the runtime user
COPY --chown=user:user app.py .

# Drop root privileges for everything that runs inside the container
USER user

# Expose port 7860 (required by HF Spaces)
EXPOSE 7860

# Set environment variables
ENV HOME=/home/user \
    PYTHONUNBUFFERED=1

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:7860/health || exit 1

# Run the application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
README.md CHANGED
@@ -7,6 +7,18 @@ sdk: docker
7
  pinned: false
8
  ---
9
 
10
- # Stress Detection API
11
 
12
- Deteksi tingkat stress dari postingan Twitter menggunakan IndoBERTweet.
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  pinned: false
8
  ---
9
 
10
+ # Twitter Stress Detection API
11
 
12
+ API untuk mendeteksi tingkat stress dari postingan Twitter menggunakan model IndoBERTweet.
13
+
14
+ ## Endpoints
15
+
16
+ - `GET /` - Info API
17
+ - `GET /health` - Health check
18
+ - `GET /analyze/{username}` - Analyze user stress level
19
+ - `GET /docs` - Interactive API documentation
20
+
21
+ ## Usage
22
+ ```bash
23
+ curl https://your-space.hf.space/analyze/username
24
+ ```
app.py CHANGED
@@ -1,4 +1,5 @@
1
- from fastapi import FastAPI
 
2
  from pydantic import BaseModel
3
  from typing import Optional
4
  import requests
@@ -7,6 +8,11 @@ from transformers import AutoTokenizer, BertForSequenceClassification
7
  from huggingface_hub import hf_hub_download
8
  import os
9
  import gc
 
 
 
 
 
10
 
11
  # -----------------------------
12
  # CONFIG
@@ -21,217 +27,281 @@ BEARER_TOKEN = os.getenv(
21
  )
22
 
23
  # -----------------------------
24
- # FASTAPI (Initialize FIRST)
25
  # -----------------------------
26
  app = FastAPI(
27
  title="Stress Detection API",
28
- description="Detect stress levels from X(Twitter) user posts",
29
- version="1.0.0"
 
 
 
 
 
 
 
 
 
 
 
30
  )
31
 
32
- # Global variables untuk lazy loading
33
  model = None
34
  tokenizer = None
35
  device = None
 
36
 
 
 
 
37
  class StressResponse(BaseModel):
38
  message: str
39
  data: Optional[dict] = None
40
 
 
 
 
 
41
 
42
  # -----------------------------
43
- # LAZY LOAD MODEL
44
  # -----------------------------
45
  def load_model_once():
46
- """Load model hanya sekali saat pertama kali dipanggil"""
47
- global model, tokenizer, device
48
-
49
- if model is not None:
50
- return # Sudah di-load
51
-
52
- print("Loading model (first time only)...")
53
-
54
- # Set device
55
- device = "cuda" if torch.cuda.is_available() else "cpu"
56
- print(f"Using device: {device}")
57
 
58
- # 1. Load tokenizer
59
- print("Loading tokenizer...")
60
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
61
 
62
- # 2. Download .pth file
63
- print(f"⬇Downloading {PT_FILE}...")
64
- model_path = hf_hub_download(
65
- repo_id=HF_MODEL_REPO,
66
- filename=PT_FILE
67
- )
68
- print(f"Downloaded to: {model_path}")
69
-
70
- # 3. Load base model dengan optimasi memory
71
- print("Loading base model...")
72
- model = BertForSequenceClassification.from_pretrained(
73
- BASE_MODEL,
74
- num_labels=2,
75
- low_cpu_mem_usage=True,
76
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
77
- )
78
-
79
- # 4. Load fine-tuned weights
80
- print("Loading fine-tuned weights...")
81
- state_dict = torch.load(model_path, map_location=device)
82
- model.load_state_dict(state_dict, strict=False)
83
-
84
- # 5. Move to device dan set eval mode
85
- model.to(device)
86
- model.eval()
87
-
88
- # 6. Clear cache
89
- gc.collect()
90
- if device == "cuda":
91
- torch.cuda.empty_cache()
92
-
93
- print("Model loaded successfully!")
94
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  # -----------------------------
97
  # HELPER FUNCTIONS
98
  # -----------------------------
99
- def get_user_id(username):
100
  """Get Twitter user ID from username"""
101
  url = f"https://api.x.com/2/users/by/username/{username}"
102
  headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
103
 
104
  try:
105
- r = requests.get(url, headers=headers, timeout=10)
106
- r.raise_for_status()
107
- return r.json()["data"]["id"], None
108
- except Exception as e:
109
- return None, {"error": str(e)}
110
-
 
 
111
 
112
- def fetch_tweets(user_id, limit=25):
113
- """Fetch user's recent tweets"""
114
  url = f"https://api.x.com/2/users/{user_id}/tweets"
115
- params = {"max_results": limit, "tweet.fields": "id,text,created_at"}
 
 
 
116
  headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
117
 
118
  try:
119
- r = requests.get(url, headers=headers, params=params, timeout=10)
120
- r.raise_for_status()
121
- tweets = r.json().get("data", [])
122
- return [t["text"] for t in tweets], None
123
- except Exception as e:
124
- return None, {"error": str(e)}
125
-
126
 
127
- def predict_stress(text):
128
  """Predict stress level from text"""
129
- inputs = tokenizer(
130
- text,
131
- return_tensors="pt",
132
- truncation=True,
133
- padding=True,
134
- max_length=128
135
- ).to(device)
136
-
137
- with torch.no_grad():
138
- outputs = model(**inputs)
139
- probs = torch.softmax(outputs.logits, dim=1)[0]
140
-
141
- label = torch.argmax(probs).item()
142
- return label, float(probs[1])
143
-
 
 
 
 
 
144
 
145
  # -----------------------------
146
- # API ENDPOINTS
147
  # -----------------------------
148
  @app.on_event("startup")
149
  async def startup_event():
150
- """Load model saat aplikasi start"""
151
- print("Starting application...")
152
  load_model_once()
153
- print("Application ready!")
154
-
155
 
156
- @app.get("/")
157
- def root():
158
- """Health check endpoint"""
 
 
 
159
  return {
 
 
160
  "status": "online",
161
- "message": "Stress Detection API is running",
162
- "model_loaded": model is not None
163
- }
164
-
165
-
166
- @app.get("/health")
167
- def health():
168
- """Detailed health check"""
169
- return {
170
- "status": "healthy",
171
- "model_loaded": model is not None,
172
- "device": str(device) if device else "not loaded",
173
- "tokenizer_loaded": tokenizer is not None
174
  }
175
 
 
 
 
 
 
 
 
 
176
 
177
  @app.get("/analyze/{username}", response_model=StressResponse)
178
- def analyze(username: str):
179
- """Analyze stress level from user's tweets"""
 
 
 
 
180
 
181
- # Pastikan model sudah loaded
182
- if model is None:
 
183
  load_model_once()
184
 
 
 
 
 
 
185
  # 1. Get user ID
186
  user_id, error = get_user_id(username)
187
  if error:
188
- return StressResponse(
189
- message=f"Failed to fetch user profile: {error.get('error', 'Unknown error')}",
190
- data=None
191
  )
192
 
193
  # 2. Fetch tweets
194
  tweets, error = fetch_tweets(user_id)
195
  if error:
196
- return StressResponse(
197
- message=f"Failed to fetch tweets: {error.get('error', 'Unknown error')}",
198
- data=None
199
  )
200
 
201
  if not tweets:
202
  return StressResponse(
203
- message="User has no tweets or account is protected.",
204
  data=None
205
  )
206
 
207
- # 3. Predict stress for each tweet
208
  labels = []
209
- for tweet in tweets:
 
 
210
  try:
211
- label, _ = predict_stress(tweet)
212
  labels.append(label)
 
 
213
  except Exception as e:
214
- print(f"Skipping tweet due to error: {e}")
215
  continue
216
 
217
  if not labels:
218
- return StressResponse(
219
- message="Failed to analyze tweets.",
220
- data=None
221
  )
222
 
223
- # 4. Calculate stress statistics
224
  stress_percentage = round(sum(labels) / len(labels) * 100, 2)
 
225
 
226
  # Determine stress status
227
  if stress_percentage <= 25:
228
- status = 0 # Low
 
229
  elif stress_percentage <= 50:
230
- status = 1 # Medium
 
231
  elif stress_percentage <= 75:
232
- status = 2 # High
 
233
  else:
234
- status = 3 # Very High
 
 
 
235
 
236
  return StressResponse(
237
  message="Analysis successful",
@@ -240,13 +310,25 @@ def analyze(username: str):
240
  "total_tweets": len(tweets),
241
  "analyzed_tweets": len(labels),
242
  "stress_level": stress_percentage,
243
- "stress_status": status
 
 
244
  }
245
  )
246
 
 
 
 
 
 
 
 
 
 
 
247
 
248
  # -----------------------------
249
- # RUN (untuk local testing)
250
  # -----------------------------
251
  if __name__ == "__main__":
252
  import uvicorn
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from typing import Optional
5
  import requests
 
8
  from huggingface_hub import hf_hub_download
9
  import os
10
  import gc
11
+ import logging
12
+
13
+ # Setup logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
 
17
  # -----------------------------
18
  # CONFIG
 
27
  )
28
 
29
  # -----------------------------
30
# FASTAPI APP
# -----------------------------
app = FastAPI(
    title="Stress Detection API",
    description="Detect stress levels from X(Twitter) user posts using IndoBERTweet",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware so browser clients on any origin can call the API.
# Credentials stay disabled: a wildcard origin must not be combined with
# credentialed requests, and none of the endpoints visible in this file read
# cookies or caller-supplied auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
48
 
49
# Global variables: shared inference state, populated by load_model_once()
model = tokenizer = device = None  # classifier, tokenizer, and "cuda"/"cpu" string
model_loaded = False  # flips to True only after a complete, successful load
54
 
55
# -----------------------------
# MODELS
# -----------------------------
class StressResponse(BaseModel):
    # Response envelope declared as the response_model of /analyze/{username}.
    # `message` is a human-readable outcome; `data` carries the statistics
    # dictionary and stays None when there is nothing to report.
    message: str
    data: Optional[dict] = None
61
 
62
class HealthResponse(BaseModel):
    # Payload of the /health endpoint.
    # `status` is a short state label, `model_loaded` mirrors the module-level
    # flag of the same name, and `device` is the device string or None.
    status: str
    model_loaded: bool
    device: Optional[str] = None
66
 
67
  # -----------------------------
68
# MODEL LOADING
# -----------------------------
def load_model_once():
    """Load the tokenizer and the fine-tuned classifier into module globals.

    The first successful call populates ``model``, ``tokenizer`` and ``device``
    and sets ``model_loaded`` to True; later calls return immediately.

    Raises:
        Exception: whatever the tokenizer/model download or weight loading
            raised. The error is logged with its traceback and re-raised, and
            ``model_loaded`` stays False so a later call can retry.
    """
    global model, tokenizer, device, model_loaded

    if model_loaded:
        logger.info("Model already loaded, skipping...")
        return

    try:
        logger.info("🔄 Starting model loading...")

        # Set device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"📱 Using device: {device}")

        # Load tokenizer
        logger.info("📝 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        logger.info("✅ Tokenizer loaded")

        # Download model weights
        logger.info(f"⬇️ Downloading {PT_FILE}...")
        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=PT_FILE,
        )
        logger.info(f" Model file downloaded: {model_path}")

        # Load base model (half precision only when running on GPU)
        logger.info("🧠 Loading base model architecture...")
        model = BertForSequenceClassification.from_pretrained(
            BASE_MODEL,
            num_labels=2,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        )
        logger.info("✅ Base model loaded")

        # Load fine-tuned weights.
        # NOTE(security): torch.load unpickles the file, so the checkpoint in
        # HF_MODEL_REPO must come from a trusted source.
        logger.info("🔧 Loading fine-tuned weights...")
        state_dict = torch.load(model_path, map_location=device)
        load_result = model.load_state_dict(state_dict, strict=False)

        # strict=False tolerates key mismatches; report them so a checkpoint
        # that does not actually match the architecture is not used silently.
        if load_result.missing_keys:
            logger.warning(
                "Checkpoint is missing %d model key(s), e.g. %s",
                len(load_result.missing_keys),
                load_result.missing_keys[:10],
            )
        if load_result.unexpected_keys:
            logger.warning(
                "Checkpoint has %d unexpected key(s), e.g. %s",
                len(load_result.unexpected_keys),
                load_result.unexpected_keys[:10],
            )
        logger.info("✅ Weights loaded")

        # Move to device and set eval mode
        model.to(device)
        model.eval()
        logger.info(f"✅ Model moved to {device} and set to eval mode")

        # Clear memory left over from loading
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        model_loaded = True
        logger.info("✅ Model loading complete!")

    except Exception as e:
        # logger.exception keeps the traceback alongside the original message
        logger.exception(f"❌ Failed to load model: {str(e)}")
        raise
130
 
131
  # -----------------------------
132
  # HELPER FUNCTIONS
133
  # -----------------------------
134
def get_user_id(username: str):
    """Look up the numeric X (Twitter) user ID for a username.

    Args:
        username: Account handle. It is percent-encoded before being placed
            in the URL path because it originates from the request path.

    Returns:
        A ``(user_id, error)`` tuple. On success ``error`` is None; on failure
        ``user_id`` is None and ``error`` is a short message string.
    """
    from urllib.parse import quote

    # safe="" also escapes "/", "?" and "#", so the handle cannot alter the
    # path or query string of the upstream request.
    url = f"https://api.x.com/2/users/by/username/{quote(username, safe='')}"
    headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}

    try:
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        return response.json()["data"]["id"], None
    except requests.exceptions.RequestException as e:
        logger.error(f"Twitter API error (get_user_id): {str(e)}")
        return None, str(e)
    except (KeyError, TypeError):
        # The response parsed as JSON but carried no data.id field
        return None, "User not found"
    except ValueError as e:
        # Body was not valid JSON (not a RequestException on older requests)
        logger.error(f"Twitter API error (get_user_id): {str(e)}")
        return None, "Invalid response from Twitter API"
148
 
149
def fetch_tweets(user_id: str, limit: int = 25):
    """Fetch the text of a user's recent tweets.

    Args:
        user_id: Numeric user ID as returned by ``get_user_id``.
        limit: Maximum number of tweets wanted. The value sent upstream is
            clamped to the 5-100 range the endpoint accepts.

    Returns:
        A ``(texts, error)`` tuple. On success ``texts`` is a list of tweet
        strings (possibly empty) and ``error`` is None; on failure ``texts``
        is None and ``error`` is a message string.
    """
    url = f"https://api.x.com/2/users/{user_id}/tweets"
    params = {
        # Values outside 5..100 are rejected by the API, so clamp both ends
        "max_results": max(5, min(limit, 100)),
        "tweet.fields": "id,text,created_at",
    }
    headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}

    try:
        response = requests.get(url, headers=headers, params=params, timeout=10)
        response.raise_for_status()
        # "data" may be absent or null when the user has no visible tweets
        tweets = response.json().get("data") or []
        texts = [tweet["text"] for tweet in tweets]
        # The lower clamp can fetch more than requested; trim back to limit
        return texts[:max(limit, 0)], None
    except requests.exceptions.RequestException as e:
        logger.error(f"Twitter API error (fetch_tweets): {str(e)}")
        return None, str(e)
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        # Non-JSON body or a payload that is not shaped like a tweet list
        logger.error(f"Twitter API error (fetch_tweets): {str(e)}")
        return None, "Unexpected response from Twitter API"
166
 
167
def predict_stress(text: str):
    """Classify one piece of text with the loaded model.

    Returns:
        A ``(label, confidence)`` pair. ``label`` is the argmax class index.
        ``confidence`` is the softmax probability of class index 1,
        whichever class was predicted.

    Raises:
        Exception: any tokenizer/model error, logged and then re-raised.
    """
    try:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        encoded = encoded.to(device)

        # Inference only: gradient tracking is not needed
        with torch.no_grad():
            logits = model(**encoded).logits
            class_probs = torch.softmax(logits, dim=1)[0]

        predicted = torch.argmax(class_probs).item()
        return predicted, float(class_probs[1])
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise
189
 
190
  # -----------------------------
191
# STARTUP EVENT
# -----------------------------
@app.on_event("startup")
async def startup_event():
    """Run at application startup: load the model before serving requests."""
    logger.info("🚀 Application starting...")
    load_model_once()  # an exception raised here propagates out of startup
    logger.info("Application ready!")
 
199
 
200
+ # -----------------------------
201
+ # API ENDPOINTS
202
+ # -----------------------------
203
@app.get("/", response_model=dict)
async def root():
    """Return static API metadata and the paths of the main endpoints."""
    endpoint_paths = {
        "health": "/health",
        "analyze": "/analyze/{username}",
        "docs": "/docs",
    }
    api_info = {
        "name": "Stress Detection API",
        "version": "1.0.0",
        "status": "online",
        "endpoints": endpoint_paths,
    }
    return api_info
216
 
217
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model has finished loading and on which device."""
    if model_loaded:
        state = "healthy"
    else:
        state = "loading"

    # device is None until load_model_once() has chosen one
    device_name = str(device) if device else None

    return HealthResponse(
        status=state,
        model_loaded=model_loaded,
        device=device_name,
    )
225
 
226
  @app.get("/analyze/{username}", response_model=StressResponse)
227
+ async def analyze_user(username: str):
228
+ """
229
+ Analyze stress level from Twitter user's recent tweets
230
+
231
+ - **username**: Twitter username (without @)
232
+ """
233
 
234
+ # Ensure model is loaded
235
+ if not model_loaded:
236
+ logger.warning("Model not loaded yet, loading now...")
237
  load_model_once()
238
 
239
+ # Remove @ if user included it
240
+ username = username.lstrip("@")
241
+
242
+ logger.info(f"📊 Analyzing user: @{username}")
243
+
244
  # 1. Get user ID
245
  user_id, error = get_user_id(username)
246
  if error:
247
+ raise HTTPException(
248
+ status_code=404,
249
+ detail=f"Failed to fetch user profile: {error}"
250
  )
251
 
252
  # 2. Fetch tweets
253
  tweets, error = fetch_tweets(user_id)
254
  if error:
255
+ raise HTTPException(
256
+ status_code=500,
257
+ detail=f"Failed to fetch tweets: {error}"
258
  )
259
 
260
  if not tweets:
261
  return StressResponse(
262
+ message="No tweets found. User may be protected or has no tweets.",
263
  data=None
264
  )
265
 
266
+ # 3. Analyze each tweet
267
  labels = []
268
+ confidences = []
269
+
270
+ for i, tweet in enumerate(tweets):
271
  try:
272
+ label, confidence = predict_stress(tweet)
273
  labels.append(label)
274
+ confidences.append(confidence)
275
+ logger.info(f"Tweet {i+1}/{len(tweets)}: label={label}, confidence={confidence:.2f}")
276
  except Exception as e:
277
+ logger.warning(f"Skipping tweet {i+1} due to error: {str(e)}")
278
  continue
279
 
280
  if not labels:
281
+ raise HTTPException(
282
+ status_code=500,
283
+ detail="Failed to analyze any tweets"
284
  )
285
 
286
+ # 4. Calculate statistics
287
  stress_percentage = round(sum(labels) / len(labels) * 100, 2)
288
+ avg_confidence = round(sum(confidences) / len(confidences) * 100, 2)
289
 
290
  # Determine stress status
291
  if stress_percentage <= 25:
292
+ status = 0
293
+ status_text = "Low Stress"
294
  elif stress_percentage <= 50:
295
+ status = 1
296
+ status_text = "Medium Stress"
297
  elif stress_percentage <= 75:
298
+ status = 2
299
+ status_text = "High Stress"
300
  else:
301
+ status = 3
302
+ status_text = "Very High Stress"
303
+
304
+ logger.info(f"✅ Analysis complete: {stress_percentage}% stress ({status_text})")
305
 
306
  return StressResponse(
307
  message="Analysis successful",
 
310
  "total_tweets": len(tweets),
311
  "analyzed_tweets": len(labels),
312
  "stress_level": stress_percentage,
313
+ "stress_status": status,
314
+ "stress_status_text": status_text,
315
+ "average_confidence": avg_confidence
316
  }
317
  )
318
 
319
+ # -----------------------------
320
+ # ERROR HANDLERS
321
+ # -----------------------------
322
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Convert any otherwise-unhandled exception into a JSON 500 response.

    The body keeps the same ``message``/``data`` shape as StressResponse.
    An exception handler must return a Response object; a bare Pydantic model
    cannot be sent by the framework, so the payload is wrapped in JSONResponse.
    """
    from fastapi.responses import JSONResponse

    # exc_info attaches the traceback of the exception being handled
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "message": f"Internal server error: {str(exc)}",
            "data": None,
        },
    )
329
 
330
  # -----------------------------
331
+ # RUN (for local testing only)
332
  # -----------------------------
333
  if __name__ == "__main__":
334
  import uvicorn