# app.py - CORRECTED VERSION
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer, RobertaModel
import uvicorn
from contextlib import asynccontextmanager

# Global variables
model = None
tokenizer = None
device = None


# THIS MUST MATCH YOUR TRAINING CODE EXACTLY
class CodeBERTClassifier(torch.nn.Module):
    def __init__(self, num_labels=4, dropout=0.3, hidden_size=256):
        super(CodeBERTClassifier, self).__init__()

        # Load pre-trained CodeBERT
        self.codebert = RobertaModel.from_pretrained('microsoft/codebert-base')

        # Dropout
        self.dropout = torch.nn.Dropout(dropout)

        # Multi-layer feedforward network - MUST MATCH TRAINING
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(768, hidden_size),               # 768 -> 256
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size, hidden_size // 2),  # 256 -> 128
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size // 2, num_labels)    # 128 -> 4
        )

    def forward(self, input_ids, attention_mask):
        # Get CodeBERT embeddings
        outputs = self.codebert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # CRITICAL: Use [CLS] token from last_hidden_state (matching training)
        pooled_output = outputs.last_hidden_state[:, 0, :]

        # Apply dropout
        pooled_output = self.dropout(pooled_output)

        # Classification
        logits = self.classifier(pooled_output)
        return logits


# Use lifespan instead of deprecated on_event
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    global model, tokenizer, device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load your trained model with EXACT same architecture
    model = CodeBERTClassifier(num_labels=4, dropout=0.3, hidden_size=256)
    model.load_state_dict(torch.load('best_codebert_model.pt', map_location=device))
    model.to(device)
    model.eval()

    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
    print(f"Model loaded successfully on {device}")

    yield

    # Shutdown
    print("Shutting down...")


app = FastAPI(title="Vulnerability Detection API", lifespan=lifespan)


class CodeRequest(BaseModel):
    code: str
    max_length: int = 512


class VulnerabilityResponse(BaseModel):
    vulnerability_type: str
    confidence: float
    is_vulnerable: bool
    label: str


@app.post("/detect", response_model=VulnerabilityResponse)
async def detect_vulnerability(request: CodeRequest):
    try:
        # Tokenize
        encoding = tokenizer(
            request.code,
            padding='max_length',
            truncation=True,
            max_length=request.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        # Predict
        with torch.no_grad():
            logits = model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)
            confidence, predicted = torch.max(probs, 1)

        # Label mapping - VERIFY THIS MATCHES YOUR TRAINING
        label_map = {0: 's0', 1: 'v0', 2: 's1', 3: 'v1'}
        vuln_type_map = {
            's0': 'SQL Injection',
            'v0': 'Certificate Validation',
            's1': 'SQL Injection',
            'v1': 'Certificate Validation'
        }

        label = label_map[predicted.item()]
        is_vulnerable = label in ['s0', 'v0']

        return VulnerabilityResponse(
            vulnerability_type=vuln_type_map[label],
            confidence=float(confidence.item()),
            is_vulnerable=is_vulnerable,
            label=label
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/")
async def root():
    return {
        "message": "Vulnerability Detection API",
        "endpoints": ["/detect", "/health", "/docs"]
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)