import os
import subprocess
import sys

import cv2
import gdown
import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms


def setup_env(path='Variations-of-SFANet-for-Crowd-Counting'):
    # Clone the reference implementation once and make it importable.
    if not os.path.exists(path):
        subprocess.run(
            [
                'git',
                'clone',
                f'https://github.com/Pongpisit-Thanasutives/{path}.git',
                path,
            ],
            capture_output=True,
            check=True,
        )
        # The repo ships its models as a plain directory; add an empty
        # __init__.py so 'models' can be imported as a package.
        with open(os.path.join(path, 'models', '__init__.py'), 'w') as f:
            f.write('')
    # Append on every run, not only after a fresh clone, otherwise the
    # import in get_model() fails when the repo already exists.
    if path not in sys.path:
        sys.path.append(path)
    return path


def get_model(path, weights):
    # Imported lazily: the module only exists after setup_env() has cloned
    # the repo and put it on sys.path.
    from models import M_SFANet_UCF_QNRF

    model = M_SFANet_UCF_QNRF.Model()
    model.load_state_dict(
        torch.load(weights, map_location=torch.device('cpu')))
    return model.eval()


def download_weights(
    url='https://drive.google.com/uc?id=1fGuH4o0hKbgdP1kaj9rbjX2HUL1IH0oo',
    out="Paper's_weights_UCF_QNRF.zip",
):
    weights = "Paper's_weights_UCF_QNRF/best_M-SFANet*_UCF_QNRF.pth"
    if os.path.exists(weights):
        return weights
    gdown.download(url, out)
    subprocess.run(
        ['unzip', out],
        capture_output=True,
        check=True,
    )
    return weights


def transform_image(img):
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Snap both sides to multiples of 16 before feeding the network.
    height, width = img.size[1], img.size[0]
    height = round(height / 16) * 16
    width = round(width / 16) * 16
    # Pass the interpolation flag by keyword: the third positional argument
    # of cv2.resize is dst, not interpolation.
    img = cv2.resize(np.array(img), (width, height),
                     interpolation=cv2.INTER_CUBIC)
    # Normalize and add a batch dimension: (3, H, W) -> (1, 3, H, W).
    return trans(Image.fromarray(img))[None, :]
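
# A quick worked example of the resizing above (illustrative numbers only):
# a 1024x771 PIL image has img.size == (1024, 771), so height becomes
# round(771 / 16) * 16 == 768 while width stays 1024, and transform_image()
# returns a tensor of shape (1, 3, 768, 1024).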


def main():
    st.write("Demo of [Encoder-Decoder Based Convolutional Neural Networks with Multi-Scale-Aware Modules for Crowd Counting](https://arxiv.org/abs/2003.05586)")  # noqa
    path = setup_env()
    weights = download_weights()
    model = get_model(path, weights)
    image_file = st.file_uploader(
        "Upload image", type=['png', 'jpg', 'jpeg'])
    if image_file is not None:
        image = Image.open(image_file).convert('RGB')
        st.image(image)
        # Inference only; no gradients needed.
        with torch.no_grad():
            density_map = model(transform_image(image))
        # Drop the batch dimension and move channels last for display.
        density_map_img = density_map.detach().numpy()[0].transpose(1, 2, 0)
        st.image(density_map_img / density_map_img.max())
        # The estimated crowd count is the sum over the density map.
        st.write("Estimated count: ", torch.sum(density_map).item())
    else:
        st.write("Example image to use that you can drag and drop:")
        st.image(Image.open('crowd.jpg').convert('RGB'))


main()
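
# Minimal sketch of running the same pipeline headlessly, outside Streamlit.
# Guarded behind a hypothetical RUN_HEADLESS_DEMO environment variable so the
# app above is unaffected; 'crowd.jpg' is just the bundled example image and
# any local RGB photo would do.
if os.environ.get('RUN_HEADLESS_DEMO'):
    repo = setup_env()
    net = get_model(repo, download_weights())
    with torch.no_grad():
        density = net(transform_image(Image.open('crowd.jpg').convert('RGB')))
    print('Estimated count:', density.sum().item())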