Spaces:
Build error
Build error
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| from ResNet_for_CC import CC_model | |
# Initialize the model on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CC_model()

# Load the pre-trained weights, adjusting for DataParallel if necessary.
# NOTE(review): assumes 'CC_net.pt' holds a flat state_dict (parameter name ->
# tensor), since it is passed straight to load_state_dict() -- confirm.
model_path = 'CC_net.pt'
checkpoint = torch.load(model_path, map_location=device)

# Keys carrying a leading 'module.' are un-prefixed so they match the bare
# model. Only the *leading* prefix is removed: a blanket str.replace() would
# also mangle keys of nested submodules that happen to be named 'module'.
_DP_PREFIX = 'module.'
if any(key.startswith(_DP_PREFIX) for key in checkpoint):
    checkpoint = {
        (key[len(_DP_PREFIX):] if key.startswith(_DP_PREFIX) else key): value
        for key, value in checkpoint.items()
    }

model.load_state_dict(checkpoint)
model.eval()  # inference mode
model.to(device)
# Image preprocessing: resize, center-crop, convert to a tensor, then
# normalize each channel with the fixed mean / std values below.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
preprocess = transforms.Compose(_preprocess_steps)
# Class names from category_names_eng.txt.
# Order matters: predict() indexes this list with the predicted class index.
class_names = [
    'T-Shirt',
    'Shirt',
    'Knitwear',
    'Chiffon',
    'Sweater',
    'Hoodie',
    'Windbreaker',
    'Jacket',
    'Downcoat',
    'Suit',
    'Shawl',
    'Dress',
    'Vest',
    'Underwear',
]
def predict(image):
    """Classify a clothing image and return a human-readable label.

    Args:
        image: Image array handed over by the Gradio image input (it must
            support ``.astype('uint8')`` and be accepted by
            ``Image.fromarray``), or ``None`` when nothing was uploaded.

    Returns:
        ``"Predicted class: <name>"`` when the predicted index has an entry
        in ``class_names``; ``"Class number: <index>"`` when it does not;
        a short prompt when ``image`` is ``None``.
    """
    # Submitting without an upload yields None; answer with a prompt instead
    # of crashing on ``None.astype``.
    if image is None:
        return "Please upload an image."

    # Let Pillow infer the mode from the array shape, then convert to RGB so
    # grayscale and RGBA inputs also reach the model as three channels.
    img = Image.fromarray(image.astype('uint8')).convert('RGB')
    batch = preprocess(img).unsqueeze(0).to(device)

    # Generate predictions; the model returns a pair and only the second
    # element is used for classification.
    with torch.no_grad():
        _, output_mean = model(batch)

    # Index of the highest score along dim 1 for the single image in the batch.
    _, predicted = torch.max(output_mean, 1)
    class_index = predicted.item()

    # Fall back to the raw index rather than raising IndexError if the model
    # emits a class that class_names does not cover.
    if 0 <= class_index < len(class_names):
        return f"Predicted class: {class_names[class_index]}"
    return f"Class number: {class_index}"
# Example images from Hugging Face. Each example is a one-element row,
# matching the single image input of the interface.
_example_files = (
    "example_image(1).JPG",
    "example_image(2).jpg",
    "example_image(3).jpg",
    "example_image(4).webp",
    "example_image(5).webp",
    "example_image(6).webp",
)
examples = [[file_name] for file_name in _example_files]
# Gradio Interface: one image input, one text output, wired to predict().
_image_input = gr.Image(label="Upload Clothing Image")
_prediction_output = gr.Textbox(label="Prediction")

interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_prediction_output,
    title="Clothing Image Classifier",
    description="This model classifies clothing images using ResNet50. Try out different examples below for a quick demonstration!",
    examples=examples,
)

# Start the web app when the script is executed.
interface.launch()