jerry2247 commited on
Commit
e8c13db
·
verified ·
1 Parent(s): b5c0709

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +135 -0
  2. logisticRegressionClassifier.joblib +3 -0
  3. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import joblib
5
+ import torch.nn as nn
6
+ from transformers import AutoImageProcessor, AutoModel
7
+ from PIL import Image
8
+ import requests
9
+ import gradio as gr
10
+
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ # -----------------------
15
+ # Your model class (unchanged)
16
+ # -----------------------
17
+ class ImageAuthenticityClassifier(nn.Module):
18
+ def __init__(self, backbone, w, b):
19
+ super().__init__()
20
+ self.backbone = backbone
21
+
22
+ d = w.shape[0]
23
+ self.head = nn.Linear(d, 1)
24
+
25
+ # Load my trained classifier head
26
+ with torch.no_grad():
27
+ self.head.weight.copy_(
28
+ w.unsqueeze(0).to(dtype=self.head.weight.dtype,
29
+ device=self.head.weight.device)
30
+ )
31
+ bias_tensor = torch.tensor(
32
+ [b],
33
+ dtype=self.head.bias.dtype,
34
+ device=self.head.bias.device,
35
+ )
36
+ self.head.bias.copy_(bias_tensor)
37
+
38
+
39
+ def forward(self, pixel_values, return_tokens: bool = False):
40
+ outputs = self.backbone(pixel_values=pixel_values)
41
+ hidden = outputs.last_hidden_state
42
+
43
+ patch_tokens = hidden[:, 1:, :]
44
+ emb = patch_tokens.mean(dim = 1)
45
+
46
+ # Apply classifier head to mean patch token embeddings
47
+ logits = self.head(emb)
48
+ prob = torch.sigmoid(logits)
49
+
50
+ if (return_tokens):
51
+ return logits, prob, emb, patch_tokens
52
+
53
+ return logits, prob, emb
54
+
55
+
56
+ # -----------------------
57
+ # Load linear classifier head for logistic regression
58
+ # -----------------------
59
+ model_save_path = "logisticRegressionClassifier.joblib"
60
+ logisticRegressionClassifier = joblib.load(model_save_path)
61
+
62
+ coef = logisticRegressionClassifier.coef_
63
+ w = torch.from_numpy(coef.squeeze(0)).float()
64
+ intercept = logisticRegressionClassifier.intercept_
65
+ b = float(intercept[0])
66
+
67
+
68
+ # -----------------------
69
+ # Load DinoV3 backbone + processor (gated repo via token)
70
+ # -----------------------
71
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
72
+ backbone = AutoModel.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN).to(device)
73
+ processor = AutoImageProcessor.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN,)
74
+ image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
75
+
76
+
77
+ # -----------------------
78
+ # Inference helper functions (unchanged)
79
+ # -----------------------
80
+ def load_image(online_image_url):
81
+ img = Image.open(requests.get(online_image_url, stream=True).raw).convert("RGB")
82
+ return img
83
+
84
+ def prepare_pixel_values(img):
85
+ inputs = processor(images=img, return_tensors="pt")
86
+ pixel_values = inputs["pixel_values"].to(device)
87
+ return pixel_values
88
+
89
+ def predict_from_online_url(online_image_url):
90
+ img = load_image(online_image_url)
91
+ pixel_values = prepare_pixel_values(img)
92
+
93
+ with torch.no_grad():
94
+ logits, prob, emb = image_auth_model(pixel_values)
95
+ return float(prob[0][0].item())
96
+
97
+
98
+ # -----------------------
99
+ # Gradio interface exposing ui_predict as a web UI/API. (AI Generated lol)
100
+ # -----------------------
101
+
102
+ def ui_predict(image_url: str):
103
+ if not image_url:
104
+ return None, "Awaiting input", "Enter an image URL to run a prediction."
105
+
106
+ try:
107
+ img = load_image(image_url)
108
+ ai_prob = float(predict_from_online_url(image_url))
109
+ percent = ai_prob * 100.0
110
+
111
+ verdict = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
112
+ headline = verdict
113
+ detail = f"{percent:.1f}% probability the image is AI-generated"
114
+
115
+ return img, headline, detail
116
+ except Exception as e:
117
+ return None, "Error", str(e)
118
+
119
+ demo = gr.Interface(
120
+ fn=ui_predict,
121
+ inputs=gr.Textbox(
122
+ label="Image URL",
123
+ placeholder="https://example.com/image.jpg",
124
+ ),
125
+ outputs=[
126
+ gr.Image(label="Preview"),
127
+ gr.Textbox(label="Verdict"),
128
+ gr.Textbox(label="Details"),
129
+ ],
130
+ title="Image Authenticity",
131
+ description="Paste an image URL to estimate how likely it is AI-generated.",
132
+ )
133
+
134
+ if __name__ == "__main__":
135
+ demo.launch()
logisticRegressionClassifier.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3162743076d6b843efa03f16727093be0c4b95aed4c9b5c550e208812fbafc9
3
+ size 7007
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ scikit-learn
5
+ joblib
6
+ numpy
7
+ pillow
8
+ gradio
9
+ requests