import os

import gradio as gr
import joblib
import numpy as np
import requests
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------
# Model: DINOv3 ViT backbone + linear (logistic regression) classifier head
# -----------------------
class ImageAuthenticityClassifier(nn.Module):
    def __init__(self, backbone, w, b):
        super().__init__()
        self.backbone = backbone
        d = w.shape[0]
        self.head = nn.Linear(d, 1)
        # Copy the trained classifier head (logistic regression weights/bias) into the linear layer
        with torch.no_grad():
            self.head.weight.copy_(
                w.unsqueeze(0).to(dtype=self.head.weight.dtype,
                                  device=self.head.weight.device)
            )
            bias_tensor = torch.tensor(
                [b],
                dtype=self.head.bias.dtype,
                device=self.head.bias.device,
            )
            self.head.bias.copy_(bias_tensor)

    def forward(self, pixel_values, return_tokens: bool = False):
        outputs = self.backbone(pixel_values=pixel_values)
        hidden = outputs.last_hidden_state
        patch_tokens = hidden[:, 1:, :]   # drop the first (CLS) token
        emb = patch_tokens.mean(dim=1)    # mean-pool the remaining token embeddings
        # Apply the classifier head to the mean-pooled embedding
        logits = self.head(emb)
        prob = torch.sigmoid(logits)
        if return_tokens:
            return logits, prob, emb, patch_tokens
        return logits, prob, emb

# -----------------------
# Load the fitted logistic regression and extract the linear head's weights/bias
# -----------------------
model_save_path = "logisticRegressionClassifier.joblib"
logisticRegressionClassifier = joblib.load(model_save_path)
coef = logisticRegressionClassifier.coef_            # shape (1, d)
w = torch.from_numpy(coef.squeeze(0)).float()        # weight vector, shape (d,)
intercept = logisticRegressionClassifier.intercept_
b = float(intercept[0])                              # scalar bias
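
# Sanity check (sketch): with these weights the PyTorch head reproduces the fitted
# logistic regression. For an (N, d) tensor `emb` of mean-pooled DINOv3 embeddings,
#   torch.sigmoid(emb @ w + b)
# should match logisticRegressionClassifier.predict_proba(emb.numpy())[:, 1].
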
# -----------------------
# Load DinoV3 backbone + processor (gated repo via token)
# -----------------------
HF_TOKEN = os.environ.get("HF_TOKEN", None)
backbone = AutoModel.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN).to(device)
processor = AutoImageProcessor.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN)
image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
image_auth_model.eval()  # inference only
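
# Note (assumed setup): the DINOv3 checkpoint is gated, so HF_TOKEN must be set in the
# environment (e.g. as a Space secret) before this script runs, for example:
#   export HF_TOKEN=hf_xxx   # placeholder token
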
# -----------------------
# Inference helper functions
# -----------------------
def load_image(online_image_url):
    # Stream the image from the URL; fail fast on HTTP errors
    response = requests.get(online_image_url, stream=True, timeout=30)
    response.raise_for_status()
    img = Image.open(response.raw).convert("RGB")
    return img

def prepare_pixel_values(img):
    inputs = processor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)
    return pixel_values

def predict_from_online_url(online_image_url):
    img = load_image(online_image_url)
    pixel_values = prepare_pixel_values(img)
    with torch.no_grad():
        logits, prob, emb = image_auth_model(pixel_values)
    return float(prob[0][0].item())
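
# Example (sketch; placeholder URL):
#   p = predict_from_online_url("https://example.com/image.jpg")
#   print(f"P(AI-generated) = {p:.3f}")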

# -----------------------
# Gradio interface exposing ui_predict as a web UI/API. (AI Generated lol)
# -----------------------
def ui_predict(image_url: str):
    if not image_url:
        return None, "Awaiting input", "Enter an image URL to run a prediction."
    try:
        # Fetch the image once and reuse it for both the preview and the prediction
        img = load_image(image_url)
        pixel_values = prepare_pixel_values(img)
        with torch.no_grad():
            _, prob, _ = image_auth_model(pixel_values)
        ai_prob = float(prob[0][0].item())
        percent = ai_prob * 100.0
        verdict = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
        detail = f"{percent:.1f}% probability the image is AI-generated"
        return img, verdict, detail
    except Exception as e:
        return None, "Error", str(e)

demo = gr.Interface(
    fn=ui_predict,
    inputs=gr.Textbox(
        label="Image URL",
        placeholder="https://example.com/image.jpg",
    ),
    outputs=[
        gr.Image(label="Preview"),
        gr.Textbox(label="Verdict"),
        gr.Textbox(label="Details"),
    ],
    title="Image Authenticity",
    description="Paste an image URL to estimate how likely it is AI-generated.",
)

if __name__ == "__main__":
    demo.launch()
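
# -----------------------
# Example client usage (sketch): query the running app with gradio_client.
# Assumes the default local URL and gr.Interface's default "/predict" endpoint;
# the image URL below is a placeholder.
# -----------------------
# from gradio_client import Client
# client = Client("http://127.0.0.1:7860/")
# preview, verdict, details = client.predict("https://example.com/image.jpg", api_name="/predict")
# print(verdict, details)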