# File size: 4,452 Bytes  (paste artifact — commented so the module parses)
# e05aa19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# app.py - CORRECTED VERSION
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer, RobertaModel
import uvicorn
from contextlib import asynccontextmanager

# Global variables
model = None
tokenizer = None
device = None

# THIS MUST MATCH YOUR TRAINING CODE EXACTLY
class CodeBERTClassifier(torch.nn.Module):
    """CodeBERT encoder with a small MLP head for 4-way classification.

    NOTE: attribute names (``codebert``, ``dropout``, ``classifier``) and
    the exact layer ordering inside ``Sequential`` must stay identical to
    the training script so ``load_state_dict`` keys line up.
    """

    def __init__(self, num_labels=4, dropout=0.3, hidden_size=256):
        super().__init__()

        # Pre-trained CodeBERT backbone (768-dim hidden states).
        self.codebert = RobertaModel.from_pretrained('microsoft/codebert-base')

        # Dropout applied to the pooled [CLS] embedding before the head.
        self.dropout = torch.nn.Dropout(dropout)

        half = hidden_size // 2
        # Two-hidden-layer MLP head: 768 -> hidden_size -> hidden_size//2 -> num_labels.
        # Layer indices (0..6) must match training for state-dict loading.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(768, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size, half),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(half, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Return raw classification logits of shape (batch, num_labels)."""
        encoder_out = self.codebert(input_ids=input_ids,
                                    attention_mask=attention_mask)
        # [CLS] token embedding from the final hidden layer (matches training).
        cls_embedding = encoder_out.last_hidden_state[:, 0, :]
        return self.classifier(self.dropout(cls_embedding))

# Use lifespan instead of deprecated on_event
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and tokenizer once at startup, log on shutdown.

    Populates the module-level ``model``, ``tokenizer`` and ``device``
    globals that the request handlers read. Raises at startup (and so
    aborts serving) if the checkpoint file is missing or incompatible.
    """
    # Startup
    global model, tokenizer, device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Rebuild the EXACT training architecture before loading weights.
    model = CodeBERTClassifier(num_labels=4, dropout=0.3, hidden_size=256)
    # weights_only=True: the checkpoint is a plain state dict, so refuse
    # arbitrary pickled objects — safer deserialization (torch >= 1.13).
    state_dict = torch.load('best_codebert_model.pt',
                            map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
    print(f"Model loaded successfully on {device}")

    yield

    # Shutdown
    print("Shutting down...")

# FastAPI application wired to the lifespan handler defined above.
app = FastAPI(title="Vulnerability Detection API", lifespan=lifespan)

# Request schema for POST /detect.
class CodeRequest(BaseModel):
    # Source-code snippet to classify.
    code: str
    # Tokenizer truncation/padding length; 512 is RoBERTa's maximum.
    max_length: int = 512

# Response schema for POST /detect.
class VulnerabilityResponse(BaseModel):
    # Human-readable category, e.g. "SQL Injection".
    vulnerability_type: str
    # Softmax probability of the predicted class, in [0, 1].
    confidence: float
    # True when the predicted label is one of the "vulnerable" labels.
    is_vulnerable: bool
    # Raw training label: one of 's0', 'v0', 's1', 'v1'.
    label: str

@app.post("/detect", response_model=VulnerabilityResponse)
async def detect_vulnerability(request: CodeRequest):
    """Classify a code snippet and report the predicted vulnerability label.

    Returns 503 if called before the model finished loading (instead of a
    misleading 500 from a ``None`` global) and 500 on unexpected inference
    errors.
    """
    # Guard against requests racing model startup / a failed load.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        # Tokenize to fixed-length tensors (pad/truncate to max_length).
        encoding = tokenizer(
            request.code,
            padding='max_length',
            truncation=True,
            max_length=request.max_length,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            logits = model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)
            confidence, predicted = torch.max(probs, 1)

        # Label mapping - VERIFY THIS MATCHES YOUR TRAINING
        label_map = {0: 's0', 1: 'v0', 2: 's1', 3: 'v1'}
        vuln_type_map = {
            's0': 'SQL Injection',
            'v0': 'Certificate Validation',
            's1': 'SQL Injection',
            'v1': 'Certificate Validation'
        }

        label = label_map[predicted.item()]
        # NOTE(review): assumes s0/v0 encode "vulnerable" and s1/v1 encode
        # "safe" variants — confirm against the training label scheme.
        is_vulnerable = label in ['s0', 'v0']

        return VulnerabilityResponse(
            vulnerability_type=vuln_type_map[label],
            confidence=float(confidence.item()),
            is_vulnerable=is_vulnerable,
            label=label
        )

    except Exception as e:
        # Surface unexpected inference failures to the client as a 500.
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the model global is populated."""
    loaded = model is not None
    return {"status": "healthy", "model_loaded": loaded}

@app.get("/")
async def root():
    """Landing endpoint listing the available routes."""
    available = ["/detect", "/health", "/docs"]
    return {
        "message": "Vulnerability Detection API",
        "endpoints": available,
    }

if __name__ == "__main__":
    # Dev entry point: serve on all interfaces, port 8000 (no reload).
    uvicorn.run(app, host="0.0.0.0", port=8000)