# app.py - CORRECTED VERSION
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer, RobertaModel
import uvicorn
from contextlib import asynccontextmanager
# Global variables
model = None
tokenizer = None
device = None
# THIS MUST MATCH YOUR TRAINING CODE EXACTLY
class CodeBERTClassifier(torch.nn.Module):
    def __init__(self, num_labels=4, dropout=0.3, hidden_size=256):
        super(CodeBERTClassifier, self).__init__()
        # Load pre-trained CodeBERT
        self.codebert = RobertaModel.from_pretrained('microsoft/codebert-base')
        # Dropout
        self.dropout = torch.nn.Dropout(dropout)
        # Multi-layer feedforward network - MUST MATCH TRAINING
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(768, hidden_size),               # 768 -> 256
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size, hidden_size // 2),  # 256 -> 128
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size // 2, num_labels)    # 128 -> 4
        )

    def forward(self, input_ids, attention_mask):
        # Get CodeBERT embeddings
        outputs = self.codebert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # CRITICAL: Use [CLS] token from last_hidden_state (matching training)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        # Apply dropout
        pooled_output = self.dropout(pooled_output)
        # Classification
        logits = self.classifier(pooled_output)
        return logits
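
# Optional sanity check (illustrative sketch, not part of the API): if
# load_state_dict() in the lifespan below fails with missing/unexpected keys,
# the checkpoint was likely saved from a different architecture. Assuming the
# checkpoint was saved with torch.save(model.state_dict(), ...), you can
# compare its keys against this class like so:
#
#   state = torch.load('best_codebert_model.pt', map_location='cpu')
#   expected = CodeBERTClassifier().state_dict().keys()
#   print(set(state.keys()) ^ set(expected))  # empty set means the keys match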
# Use lifespan instead of deprecated on_event
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    global model, tokenizer, device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load your trained model with EXACT same architecture
    model = CodeBERTClassifier(num_labels=4, dropout=0.3, hidden_size=256)
    model.load_state_dict(torch.load('best_codebert_model.pt', map_location=device))
    model.to(device)
    model.eval()
    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
    print(f"Model loaded successfully on {device}")
    yield
    # Shutdown
    print("Shutting down...")

app = FastAPI(title="Vulnerability Detection API", lifespan=lifespan)
class CodeRequest(BaseModel):
    code: str
    max_length: int = 512

class VulnerabilityResponse(BaseModel):
    vulnerability_type: str
    confidence: float
    is_vulnerable: bool
    label: str

@app.post("/detect", response_model=VulnerabilityResponse)
async def detect_vulnerability(request: CodeRequest):
    try:
        # Tokenize
        encoding = tokenizer(
            request.code,
            padding='max_length',
            truncation=True,
            max_length=request.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        # Predict
        with torch.no_grad():
            logits = model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)
            confidence, predicted = torch.max(probs, 1)
        # Label mapping - VERIFY THIS MATCHES YOUR TRAINING
        label_map = {0: 's0', 1: 'v0', 2: 's1', 3: 'v1'}
        vuln_type_map = {
            's0': 'SQL Injection',
            'v0': 'Certificate Validation',
            's1': 'SQL Injection',
            'v1': 'Certificate Validation'
        }
        label = label_map[predicted.item()]
        is_vulnerable = label in ['s0', 'v0']
        return VulnerabilityResponse(
            vulnerability_type=vuln_type_map[label],
            confidence=float(confidence.item()),
            is_vulnerable=is_vulnerable,
            label=label
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
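
# Quick in-process test of /detect (illustrative sketch; uses FastAPI's
# TestClient, which requires httpx to be installed; the input snippet is a
# made-up example):
#
#   from fastapi.testclient import TestClient
#   with TestClient(app) as client:  # context manager runs the lifespan startup
#       print(client.post("/detect", json={"code": "eval(user_input)"}).json())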
@app.get("/health")
async def health_check():
return {"status": "healthy", "model_loaded": model is not None}
@app.get("/")
async def root():
    return {
        "message": "Vulnerability Detection API",
        "endpoints": ["/detect", "/health", "/docs"]
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
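
# Example client call over HTTP (illustrative sketch; assumes the server is
# running on localhost:8000 and the 'requests' package is installed; the code
# snippet in the payload is a made-up example):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/detect",
#       json={"code": "cursor.execute('SELECT * FROM users WHERE id=' + uid)"},
#   )
#   print(resp.json())  # JSON matching the VulnerabilityResponse schema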