Spaces:
Running
Running
| import os | |
| import numpy as np | |
| import torch | |
| import joblib | |
| import torch.nn as nn | |
| from transformers import AutoImageProcessor, AutoModel | |
| from PIL import Image | |
| import requests | |
| import gradio as gr | |
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# -----------------------
# Model: backbone + linear head initialised from a trained classifier
# -----------------------
class ImageAuthenticityClassifier(nn.Module):
    """Binary classifier scoring how likely an image is AI-generated.

    The backbone is run on the input pixels. Every token of
    ``last_hidden_state`` except the first is mean-pooled, and the pooled
    embedding is passed through a single-output linear layer whose parameters
    are copied from ``w`` and ``b``.

    Args:
        backbone: Module called as ``backbone(pixel_values=...)`` whose output
            exposes ``last_hidden_state``.
        w: 1-D weight vector of the trained head; its length sets the head's
            input dimension.
        b: Scalar bias (intercept) of the trained head.
    """

    def __init__(self, backbone, w, b):
        super().__init__()
        self.backbone = backbone
        in_features = w.shape[0]
        self.head = nn.Linear(in_features, 1)

        weight, bias = self.head.weight, self.head.bias
        # Overwrite the randomly initialised head with the trained parameters,
        # cast to the head's own dtype/device; no_grad keeps this out of autograd.
        with torch.no_grad():
            weight.copy_(
                w.unsqueeze(0).to(dtype=weight.dtype, device=weight.device)
            )
            bias.copy_(torch.tensor([b], dtype=bias.dtype, device=bias.device))

    def forward(self, pixel_values, return_tokens: bool = False):
        """Score a batch of images.

        Returns:
            ``(logits, prob, emb)``, or ``(logits, prob, emb, patch_tokens)``
            when ``return_tokens`` is true. ``prob`` is ``sigmoid(logits)`` and
            ``emb`` is the mean of the non-first tokens fed to the head.
        """
        last_hidden = self.backbone(pixel_values=pixel_values).last_hidden_state
        # Drop token 0 (presumably CLS) and mean-pool the remaining tokens.
        # NOTE(review): DINOv3 checkpoints may also prepend register tokens after
        # CLS; slicing from index 1 averages those in as well. Confirm this
        # matches how the head's training embeddings were computed before
        # changing the slice.
        tokens = last_hidden[:, 1:, :]
        pooled = tokens.mean(dim=1)
        logits = self.head(pooled)
        prob = torch.sigmoid(logits)
        if not return_tokens:
            return logits, prob, pooled
        return logits, prob, pooled, tokens
# -----------------------
# Load the trained linear (logistic-regression) head and the DINOv3 backbone
# -----------------------
model_save_path = "logisticRegressionClassifier.joblib"
# NOTE: joblib.load unpickles the file and can execute arbitrary code; only
# ever point this at a trusted artifact shipped with the app.
logisticRegressionClassifier = joblib.load(model_save_path)
coef = logisticRegressionClassifier.coef_
# assumes a binary sklearn-style model: coef_ is (1, d), intercept_ is (1,)
# — TODO confirm against the training script.
w = torch.from_numpy(coef.squeeze(0)).float()
intercept = logisticRegressionClassifier.intercept_
b = float(intercept[0])

# The backbone repo is gated; an access token is read from the HF_TOKEN
# environment variable (None means anonymous access).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
BACKBONE_ID = "facebook/dinov3-vitb16-pretrain-lvd1689m"
backbone = AutoModel.from_pretrained(BACKBONE_ID, token=HF_TOKEN).to(device)
processor = AutoImageProcessor.from_pretrained(BACKBONE_ID, token=HF_TOKEN)
image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
# This app only runs inference: put the whole model (wrapper included, which
# nn.Module creates in training mode) into eval mode so dropout and other
# train-only behaviour is disabled.
image_auth_model.eval()
# -----------------------
# Inference helper functions
# -----------------------
def load_image(online_image_url, *, timeout: float = 10.0):
    """Download an image from a URL and return it as an RGB PIL image.

    Args:
        online_image_url: HTTP(S) URL of the image.
        timeout: Seconds to wait on connect and on each read before giving up.
            Without a timeout, a stalled server would block the call forever.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
        PIL.UnidentifiedImageError: if the body is not a decodable image.
    """
    # Local import keeps this helper self-contained.
    import io

    # NOTE(review): this URL comes from the web UI, so the server fetches
    # arbitrary user-supplied addresses (SSRF risk); consider an allow-list.
    with requests.get(online_image_url, timeout=timeout) as response:
        # Fail loudly on error pages instead of handing HTML to PIL.
        response.raise_for_status()
        payload = response.content
    # Decode from the buffered body: unlike ``response.raw`` this honours
    # Content-Encoding (gzip/deflate), and the connection is already released.
    img = Image.open(io.BytesIO(payload)).convert("RGB")
    return img
def prepare_pixel_values(img):
    """Run the image processor on ``img`` and return its ``pixel_values`` tensor on ``device``."""
    batch = processor(images=img, return_tensors="pt")
    return batch["pixel_values"].to(device)
def predict_from_online_url(online_image_url):
    """Fetch the image at ``online_image_url`` and return the model's sigmoid score as a float.

    The score is the probability that the image is AI-generated.
    """
    batch_pixels = prepare_pixel_values(load_image(online_image_url))
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        _logits, probabilities, _emb = image_auth_model(batch_pixels)
    return float(probabilities[0][0].item())
# -----------------------
# Gradio callback: URL in, (preview image, verdict, detail text) out
# -----------------------
def ui_predict(image_url: str):
    """Score the image at ``image_url`` for display in the web UI.

    Args:
        image_url: URL typed by the user. Surrounding whitespace is ignored;
            an empty (or missing) value returns a placeholder instead of
            erroring.

    Returns:
        A 3-tuple ``(image_or_None, headline, detail)``. On any failure the
        headline is ``"Error"`` and the detail holds the exception message.
    """
    image_url = (image_url or "").strip()
    if not image_url:
        return None, "Awaiting input", "Enter an image URL to run a prediction."
    try:
        # Download once and reuse the same image for both preview and scoring.
        # Going through predict_from_online_url would fetch the URL a second
        # time, and the preview could then differ from what was scored.
        img = load_image(image_url)
        pixel_values = prepare_pixel_values(img)
        with torch.no_grad():
            _, prob, _ = image_auth_model(pixel_values)
        ai_prob = float(prob[0][0].item())
    except Exception as e:  # UI boundary: report every failure to the user
        # Some exceptions stringify to "", so fall back to the type name.
        return None, "Error", str(e) or type(e).__name__
    percent = ai_prob * 100.0
    headline = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
    detail = f"{percent:.1f}% probability the image is AI-generated"
    return img, headline, detail
def _build_demo():
    """Assemble the Gradio interface that exposes ``ui_predict`` as a web UI/API."""
    url_box = gr.Textbox(
        label="Image URL",
        placeholder="https://example.com/image.jpg",
    )
    preview = gr.Image(label="Preview")
    verdict_box = gr.Textbox(label="Verdict")
    details_box = gr.Textbox(label="Details")
    return gr.Interface(
        fn=ui_predict,
        inputs=url_box,
        outputs=[preview, verdict_box, details_box],
        title="Image Authenticity",
        description="Paste an image URL to estimate how likely it is AI-generated.",
    )


demo = _build_demo()
def _main():
    """Start the Gradio server for ``demo`` when this file is executed directly."""
    demo.launch()


if __name__ == "__main__":
    _main()