Lahiru-LK commited on
Commit
e05aa19
·
verified ·
1 Parent(s): 3c39fe2

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +24 -0
  2. app.py +142 -0
  3. requirements.txt +5 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # System dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ build-essential \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Python dependencies
11
+ COPY requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy model file (487 MB)
15
+ COPY best_codebert_model.pt .
16
+
17
+ # Copy application
18
+ COPY app.py .
19
+
20
+ # Expose port
21
+ EXPOSE 7860
22
+
23
+ # Run the application
24
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - CORRECTED VERSION
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ import torch
5
+ from transformers import RobertaTokenizer, RobertaModel
6
+ import uvicorn
7
+ from contextlib import asynccontextmanager
8
+
9
+ # Global variables
10
+ model = None
11
+ tokenizer = None
12
+ device = None
13
+
14
+ # THIS MUST MATCH YOUR TRAINING CODE EXACTLY
15
+ class CodeBERTClassifier(torch.nn.Module):
16
+ def __init__(self, num_labels=4, dropout=0.3, hidden_size=256):
17
+ super(CodeBERTClassifier, self).__init__()
18
+
19
+ # Load pre-trained CodeBERT
20
+ self.codebert = RobertaModel.from_pretrained('microsoft/codebert-base')
21
+
22
+ # Dropout
23
+ self.dropout = torch.nn.Dropout(dropout)
24
+
25
+ # Multi-layer feedforward network - MUST MATCH TRAINING
26
+ self.classifier = torch.nn.Sequential(
27
+ torch.nn.Linear(768, hidden_size), # 768 -> 256
28
+ torch.nn.ReLU(),
29
+ torch.nn.Dropout(dropout),
30
+ torch.nn.Linear(hidden_size, hidden_size // 2), # 256 -> 128
31
+ torch.nn.ReLU(),
32
+ torch.nn.Dropout(dropout),
33
+ torch.nn.Linear(hidden_size // 2, num_labels) # 128 -> 4
34
+ )
35
+
36
+ def forward(self, input_ids, attention_mask):
37
+ # Get CodeBERT embeddings
38
+ outputs = self.codebert(
39
+ input_ids=input_ids,
40
+ attention_mask=attention_mask
41
+ )
42
+
43
+ # CRITICAL: Use [CLS] token from last_hidden_state (matching training)
44
+ pooled_output = outputs.last_hidden_state[:, 0, :]
45
+
46
+ # Apply dropout
47
+ pooled_output = self.dropout(pooled_output)
48
+
49
+ # Classification
50
+ logits = self.classifier(pooled_output)
51
+
52
+ return logits
53
+
54
+ # Use lifespan instead of deprecated on_event
55
+ @asynccontextmanager
56
+ async def lifespan(app: FastAPI):
57
+ # Startup
58
+ global model, tokenizer, device
59
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
+
61
+ # Load your trained model with EXACT same architecture
62
+ model = CodeBERTClassifier(num_labels=4, dropout=0.3, hidden_size=256)
63
+ model.load_state_dict(torch.load('best_codebert_model.pt', map_location=device))
64
+ model.to(device)
65
+ model.eval()
66
+
67
+ tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
68
+ print(f"Model loaded successfully on {device}")
69
+
70
+ yield
71
+
72
+ # Shutdown
73
+ print("Shutting down...")
74
+
75
+ app = FastAPI(title="Vulnerability Detection API", lifespan=lifespan)
76
+
77
+ class CodeRequest(BaseModel):
78
+ code: str
79
+ max_length: int = 512
80
+
81
+ class VulnerabilityResponse(BaseModel):
82
+ vulnerability_type: str
83
+ confidence: float
84
+ is_vulnerable: bool
85
+ label: str
86
+
87
+ @app.post("/detect", response_model=VulnerabilityResponse)
88
+ async def detect_vulnerability(request: CodeRequest):
89
+ try:
90
+ # Tokenize
91
+ encoding = tokenizer(
92
+ request.code,
93
+ padding='max_length',
94
+ truncation=True,
95
+ max_length=request.max_length,
96
+ return_tensors='pt'
97
+ )
98
+
99
+ input_ids = encoding['input_ids'].to(device)
100
+ attention_mask = encoding['attention_mask'].to(device)
101
+
102
+ # Predict
103
+ with torch.no_grad():
104
+ logits = model(input_ids, attention_mask)
105
+ probs = torch.softmax(logits, dim=1)
106
+ confidence, predicted = torch.max(probs, 1)
107
+
108
+ # Label mapping - VERIFY THIS MATCHES YOUR TRAINING
109
+ label_map = {0: 's0', 1: 'v0', 2: 's1', 3: 'v1'}
110
+ vuln_type_map = {
111
+ 's0': 'SQL Injection',
112
+ 'v0': 'Certificate Validation',
113
+ 's1': 'SQL Injection',
114
+ 'v1': 'Certificate Validation'
115
+ }
116
+
117
+ label = label_map[predicted.item()]
118
+ is_vulnerable = label in ['s0', 'v0']
119
+
120
+ return VulnerabilityResponse(
121
+ vulnerability_type=vuln_type_map[label],
122
+ confidence=float(confidence.item()),
123
+ is_vulnerable=is_vulnerable,
124
+ label=label
125
+ )
126
+
127
+ except Exception as e:
128
+ raise HTTPException(status_code=500, detail=str(e))
129
+
130
+ @app.get("/health")
131
+ async def health_check():
132
+ return {"status": "healthy", "model_loaded": model is not None}
133
+
134
+ @app.get("/")
135
+ async def root():
136
+ return {
137
+ "message": "Vulnerability Detection API",
138
+ "endpoints": ["/detect", "/health", "/docs"]
139
+ }
140
+
141
+ if __name__ == "__main__":
142
+ uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ torch==2.1.0
4
+ transformers==4.35.0
5
+ pydantic==2.5.0