Commit
·
b412062
0
Parent(s):
Deploy Chest X-Ray App (LFS)
Browse files- .gitattributes +4 -0
- .gitignore +45 -0
- .gradio/certificate.pem +31 -0
- README.md +28 -0
- data/Chest_Xray_PA_3-8-2010.png +3 -0
- data/cxr14_subset_labels.csv +0 -0
- data/large-pneumothorax-5.jpeg +3 -0
- data/precomputed_image_embeddings.npz +3 -0
- data/precomputed_text_embeddings.npz +3 -0
- data/test_xray.png +3 -0
- requirements.txt +13 -0
- results/kaggle_predictions.csv +251 -0
- results/kaggle_roc_curve.png +3 -0
- results/roc_PNEUMOTHORAX.png +3 -0
- src/app.py +110 -0
- src/calculate_threshold.py +54 -0
- src/create_dummy_image.py +13 -0
- src/dicom_utils.py +49 -0
- src/download.py +43 -0
- src/evaluate.py +43 -0
- src/evaluate_kaggle.py +142 -0
- src/main.py +106 -0
- src/model.py +171 -0
- src/plot_kaggle_roc.py +62 -0
.gitattributes
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
env/
|
| 8 |
+
build/
|
| 9 |
+
develop-eggs/
|
| 10 |
+
dist/
|
| 11 |
+
downloads/
|
| 12 |
+
eggs/
|
| 13 |
+
.eggs/
|
| 14 |
+
lib/
|
| 15 |
+
lib64/
|
| 16 |
+
parts/
|
| 17 |
+
sdist/
|
| 18 |
+
var/
|
| 19 |
+
wheels/
|
| 20 |
+
*.egg-info/
|
| 21 |
+
.installed.cfg
|
| 22 |
+
*.egg
|
| 23 |
+
|
| 24 |
+
# MacOS
|
| 25 |
+
.DS_Store
|
| 26 |
+
|
| 27 |
+
# Virtual Env
|
| 28 |
+
venv/
|
| 29 |
+
miniconda3/
|
| 30 |
+
|
| 31 |
+
# Project Data
|
| 32 |
+
# Exclude the large Kaggle dataset
|
| 33 |
+
data/kaggle/
|
| 34 |
+
# Exclude raw image download if any (keep precomputed embeddings)
|
| 35 |
+
# data/*.png
|
| 36 |
+
# data/*.jpg
|
| 37 |
+
# data/*.jpeg
|
| 38 |
+
# data/*.dcm
|
| 39 |
+
|
| 40 |
+
# But FORCE include the necessary precomputed files and results for the app
|
| 41 |
+
!data/precomputed_text_embeddings.npz
|
| 42 |
+
!data/cxr14_subset_labels.csv
|
| 43 |
+
!results/kaggle_roc_curve.png
|
| 44 |
+
!results/roc_PNEUMOTHORAX.png
|
| 45 |
+
!data/google-health/
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Chest X-Ray Zero-Shot Classifier
|
| 3 |
+
emoji: 🩻
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.47.2
|
| 8 |
+
app_file: src/app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Chest X-Ray Zero-Shot Classifier
|
| 14 |
+
|
| 15 |
+
This application uses the **Google CXR Foundation** model to perform zero-shot classification of Chest X-Rays for **Pneumothorax**.
|
| 16 |
+
|
| 17 |
+
## Detection Logic
|
| 18 |
+
- **Model**: `google/cxr-foundation` (ELIXR-C Image Encoder + QFormer)
|
| 19 |
+
- **Method**: Zero-Shot Classification comparing image embeddings to text embeddings ("small pneumothorax" vs "no pneumothorax").
|
| 20 |
+
- **Binary Threshold**: `-0.1173` (Calibrated on a local Kaggle Pneumothorax dataset using Youden's J statistic).
|
| 21 |
+
|
| 22 |
+
## Performance
|
| 23 |
+
- **Local Kaggle Dataset AUC**: 0.8804
|
| 24 |
+
|
| 25 |
+
## How to use
|
| 26 |
+
1. Upload a valid Chest X-Ray as a PNG or JPG image (convert DICOM files to PNG/JPG first).
|
| 27 |
+
2. Click "Analyze Image".
|
| 28 |
+
3. View the prediction and the raw zero-shot score (compared against the binary threshold above).
|
data/Chest_Xray_PA_3-8-2010.png
ADDED
|
Git LFS Details
|
data/cxr14_subset_labels.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/large-pneumothorax-5.jpeg
ADDED
|
Git LFS Details
|
data/precomputed_image_embeddings.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c307ffb4f109160e1e64780e4c8522642d6581badb593d94717634c4a76574e
|
| 3 |
+
size 45543702
|
data/precomputed_text_embeddings.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3de8680df438bf1adda19f8229179aecfc43c050771bb2582af4964fb76fa1d6
|
| 3 |
+
size 931906
|
data/test_xray.png
ADDED
|
Git LFS Details
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas
|
| 2 |
+
numpy
|
| 3 |
+
scikit-learn
|
| 4 |
+
matplotlib
|
| 5 |
+
wget
|
| 6 |
+
huggingface_hub
|
| 7 |
+
pypng
|
| 8 |
+
Pillow
|
| 9 |
+
tensorflow-text; sys_platform != 'win32'
|
| 10 |
+
# tensorflow-metal; sys_platform == 'darwin' and platform_machine == 'arm64' # Disabled to avoid conflict with tensorflow-text
|
| 11 |
+
tensorflow
|
| 12 |
+
pydicom
|
| 13 |
+
gradio
|
results/kaggle_predictions.csv
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
file,true_label,pneumothorax_score
|
| 2 |
+
train/No Pneumothorax/000000.dcm,No Pneumothorax,-0.10425025224685669
|
| 3 |
+
train/Pneumothorax/000001.dcm,Pneumothorax,-0.0685623288154602
|
| 4 |
+
train/No Pneumothorax/000002.dcm,No Pneumothorax,-0.16316622495651245
|
| 5 |
+
train/Pneumothorax/000003.dcm,Pneumothorax,-0.09676074981689453
|
| 6 |
+
train/Pneumothorax/000004.dcm,Pneumothorax,-0.1509198546409607
|
| 7 |
+
train/No Pneumothorax/000005.dcm,No Pneumothorax,-0.1722637414932251
|
| 8 |
+
train/No Pneumothorax/000006.dcm,No Pneumothorax,-0.10494697093963623
|
| 9 |
+
train/No Pneumothorax/000007.dcm,No Pneumothorax,-0.21636486053466797
|
| 10 |
+
train/No Pneumothorax/000008.dcm,No Pneumothorax,-0.20555296540260315
|
| 11 |
+
train/No Pneumothorax/000009.dcm,No Pneumothorax,-0.21164917945861816
|
| 12 |
+
train/Pneumothorax/000010.dcm,Pneumothorax,-0.0807306170463562
|
| 13 |
+
train/No Pneumothorax/000011.dcm,No Pneumothorax,-0.18297719955444336
|
| 14 |
+
train/No Pneumothorax/000012.dcm,No Pneumothorax,-0.20279890298843384
|
| 15 |
+
train/Pneumothorax/000013.dcm,Pneumothorax,-0.13381218910217285
|
| 16 |
+
train/No Pneumothorax/000014.dcm,No Pneumothorax,-0.15488970279693604
|
| 17 |
+
train/No Pneumothorax/000015.dcm,No Pneumothorax,-0.11352282762527466
|
| 18 |
+
train/Pneumothorax/000016.dcm,Pneumothorax,-0.15156656503677368
|
| 19 |
+
train/No Pneumothorax/000017.dcm,No Pneumothorax,-0.12871450185775757
|
| 20 |
+
train/No Pneumothorax/000018.dcm,No Pneumothorax,-0.1893666386604309
|
| 21 |
+
train/No Pneumothorax/000019.dcm,No Pneumothorax,-0.19539695978164673
|
| 22 |
+
train/No Pneumothorax/000020.dcm,No Pneumothorax,-0.14108973741531372
|
| 23 |
+
train/Pneumothorax/000021.dcm,Pneumothorax,-0.1235361099243164
|
| 24 |
+
train/No Pneumothorax/000022.dcm,No Pneumothorax,-0.18652945756912231
|
| 25 |
+
train/No Pneumothorax/000023.dcm,No Pneumothorax,-0.217128723859787
|
| 26 |
+
train/No Pneumothorax/000024.dcm,No Pneumothorax,-0.17158114910125732
|
| 27 |
+
train/No Pneumothorax/000025.dcm,No Pneumothorax,-0.15493667125701904
|
| 28 |
+
train/No Pneumothorax/000026.dcm,No Pneumothorax,-0.19961446523666382
|
| 29 |
+
train/No Pneumothorax/000027.dcm,No Pneumothorax,-0.046631813049316406
|
| 30 |
+
train/No Pneumothorax/000028.dcm,No Pneumothorax,-0.13155603408813477
|
| 31 |
+
train/Pneumothorax/000029.dcm,Pneumothorax,-0.08408623933792114
|
| 32 |
+
train/No Pneumothorax/000030.dcm,No Pneumothorax,-0.1320541501045227
|
| 33 |
+
train/No Pneumothorax/000031.dcm,No Pneumothorax,-0.20212656259536743
|
| 34 |
+
train/No Pneumothorax/000032.dcm,No Pneumothorax,-0.1174507737159729
|
| 35 |
+
train/No Pneumothorax/000033.dcm,No Pneumothorax,-0.22716906666755676
|
| 36 |
+
train/No Pneumothorax/000034.dcm,No Pneumothorax,-0.19766950607299805
|
| 37 |
+
train/No Pneumothorax/000035.dcm,No Pneumothorax,-0.14044338464736938
|
| 38 |
+
train/No Pneumothorax/000036.dcm,No Pneumothorax,-0.20672184228897095
|
| 39 |
+
train/No Pneumothorax/000037.dcm,No Pneumothorax,-0.12445247173309326
|
| 40 |
+
train/Pneumothorax/000038.dcm,Pneumothorax,-0.0037876367568969727
|
| 41 |
+
train/No Pneumothorax/000039.dcm,No Pneumothorax,-0.10316938161849976
|
| 42 |
+
train/No Pneumothorax/000040.dcm,No Pneumothorax,-0.17113035917282104
|
| 43 |
+
train/No Pneumothorax/000041.dcm,No Pneumothorax,-0.19905823469161987
|
| 44 |
+
train/Pneumothorax/000042.dcm,Pneumothorax,-0.16009891033172607
|
| 45 |
+
train/Pneumothorax/000043.dcm,Pneumothorax,-0.03364861011505127
|
| 46 |
+
train/No Pneumothorax/000044.dcm,No Pneumothorax,-0.09765869379043579
|
| 47 |
+
train/No Pneumothorax/000045.dcm,No Pneumothorax,-0.1575358510017395
|
| 48 |
+
train/Pneumothorax/000046.dcm,Pneumothorax,-0.07842230796813965
|
| 49 |
+
train/No Pneumothorax/000047.dcm,No Pneumothorax,-0.11831718683242798
|
| 50 |
+
train/No Pneumothorax/000048.dcm,No Pneumothorax,-0.13761907815933228
|
| 51 |
+
train/Pneumothorax/000049.dcm,Pneumothorax,-0.07772612571716309
|
| 52 |
+
train/Pneumothorax/000050.dcm,Pneumothorax,-0.06707853078842163
|
| 53 |
+
train/No Pneumothorax/000051.dcm,No Pneumothorax,-0.14870381355285645
|
| 54 |
+
train/Pneumothorax/000052.dcm,Pneumothorax,-0.14959675073623657
|
| 55 |
+
train/No Pneumothorax/000053.dcm,No Pneumothorax,-0.10760456323623657
|
| 56 |
+
train/No Pneumothorax/000054.dcm,No Pneumothorax,-0.12791478633880615
|
| 57 |
+
train/Pneumothorax/000055.dcm,Pneumothorax,-0.08234262466430664
|
| 58 |
+
train/No Pneumothorax/000056.dcm,No Pneumothorax,-0.17744576930999756
|
| 59 |
+
train/No Pneumothorax/000057.dcm,No Pneumothorax,-0.09609729051589966
|
| 60 |
+
train/Pneumothorax/000058.dcm,Pneumothorax,-0.11940544843673706
|
| 61 |
+
train/No Pneumothorax/000059.dcm,No Pneumothorax,-0.14210587739944458
|
| 62 |
+
train/Pneumothorax/000060.dcm,Pneumothorax,-0.10487139225006104
|
| 63 |
+
train/Pneumothorax/000061.dcm,Pneumothorax,-0.09002417325973511
|
| 64 |
+
train/No Pneumothorax/000062.dcm,No Pneumothorax,-0.12446761131286621
|
| 65 |
+
train/No Pneumothorax/000063.dcm,No Pneumothorax,-0.14532139897346497
|
| 66 |
+
train/Pneumothorax/000064.dcm,Pneumothorax,-0.056851863861083984
|
| 67 |
+
train/Pneumothorax/000065.dcm,Pneumothorax,-0.0926550030708313
|
| 68 |
+
train/No Pneumothorax/000066.dcm,No Pneumothorax,-0.1890377700328827
|
| 69 |
+
train/No Pneumothorax/000067.dcm,No Pneumothorax,-0.153730571269989
|
| 70 |
+
train/Pneumothorax/000068.dcm,Pneumothorax,-0.10636574029922485
|
| 71 |
+
train/No Pneumothorax/000069.dcm,No Pneumothorax,-0.15283668041229248
|
| 72 |
+
train/Pneumothorax/000070.dcm,Pneumothorax,-0.06824761629104614
|
| 73 |
+
train/No Pneumothorax/000071.dcm,No Pneumothorax,-0.13752543926239014
|
| 74 |
+
train/No Pneumothorax/000072.dcm,No Pneumothorax,-0.17813771963119507
|
| 75 |
+
train/No Pneumothorax/000073.dcm,No Pneumothorax,-0.09899967908859253
|
| 76 |
+
train/No Pneumothorax/000074.dcm,No Pneumothorax,-0.17760634422302246
|
| 77 |
+
train/No Pneumothorax/000075.dcm,No Pneumothorax,-0.1572226881980896
|
| 78 |
+
train/Pneumothorax/000076.dcm,Pneumothorax,-0.14221316576004028
|
| 79 |
+
train/No Pneumothorax/000077.dcm,No Pneumothorax,-0.2066076397895813
|
| 80 |
+
train/No Pneumothorax/000078.dcm,No Pneumothorax,-0.12766695022583008
|
| 81 |
+
train/No Pneumothorax/000079.dcm,No Pneumothorax,-0.19809648394584656
|
| 82 |
+
train/No Pneumothorax/000080.dcm,No Pneumothorax,-0.10679352283477783
|
| 83 |
+
train/Pneumothorax/000081.dcm,Pneumothorax,-0.09845387935638428
|
| 84 |
+
train/No Pneumothorax/000082.dcm,No Pneumothorax,-0.17694079875946045
|
| 85 |
+
train/Pneumothorax/000083.dcm,Pneumothorax,-0.13586944341659546
|
| 86 |
+
train/Pneumothorax/000084.dcm,Pneumothorax,-0.058757483959198
|
| 87 |
+
train/No Pneumothorax/000085.dcm,No Pneumothorax,-0.13550400733947754
|
| 88 |
+
train/Pneumothorax/000086.dcm,Pneumothorax,-0.0947638750076294
|
| 89 |
+
train/No Pneumothorax/000087.dcm,No Pneumothorax,-0.09306931495666504
|
| 90 |
+
train/No Pneumothorax/000088.dcm,No Pneumothorax,-0.15457135438919067
|
| 91 |
+
train/No Pneumothorax/000089.dcm,No Pneumothorax,-0.1434255838394165
|
| 92 |
+
train/No Pneumothorax/000090.dcm,No Pneumothorax,-0.15700900554656982
|
| 93 |
+
train/No Pneumothorax/000091.dcm,No Pneumothorax,-0.1415807604789734
|
| 94 |
+
train/No Pneumothorax/000092.dcm,No Pneumothorax,-0.1890210211277008
|
| 95 |
+
train/No Pneumothorax/000093.dcm,No Pneumothorax,-0.12824022769927979
|
| 96 |
+
train/No Pneumothorax/000094.dcm,No Pneumothorax,-0.20132136344909668
|
| 97 |
+
train/No Pneumothorax/000095.dcm,No Pneumothorax,-0.21727478504180908
|
| 98 |
+
train/No Pneumothorax/000096.dcm,No Pneumothorax,-0.1915437877178192
|
| 99 |
+
train/No Pneumothorax/000097.dcm,No Pneumothorax,-0.1858217716217041
|
| 100 |
+
train/No Pneumothorax/000098.dcm,No Pneumothorax,-0.1721271276473999
|
| 101 |
+
train/No Pneumothorax/000099.dcm,No Pneumothorax,-0.16763722896575928
|
| 102 |
+
train/No Pneumothorax/000100.dcm,No Pneumothorax,-0.17117935419082642
|
| 103 |
+
train/Pneumothorax/000101.dcm,Pneumothorax,-0.11121219396591187
|
| 104 |
+
train/No Pneumothorax/000102.dcm,No Pneumothorax,-0.1796356439590454
|
| 105 |
+
train/No Pneumothorax/000103.dcm,No Pneumothorax,-0.21206533908843994
|
| 106 |
+
train/No Pneumothorax/000104.dcm,No Pneumothorax,-0.16594678163528442
|
| 107 |
+
train/No Pneumothorax/000105.dcm,No Pneumothorax,-0.13096767663955688
|
| 108 |
+
train/Pneumothorax/000106.dcm,Pneumothorax,-0.10808700323104858
|
| 109 |
+
train/Pneumothorax/000107.dcm,Pneumothorax,-0.054500699043273926
|
| 110 |
+
train/No Pneumothorax/000108.dcm,No Pneumothorax,-0.17718034982681274
|
| 111 |
+
train/No Pneumothorax/000109.dcm,No Pneumothorax,-0.18512558937072754
|
| 112 |
+
train/No Pneumothorax/000110.dcm,No Pneumothorax,-0.14153480529785156
|
| 113 |
+
train/Pneumothorax/000111.dcm,Pneumothorax,-0.07206696271896362
|
| 114 |
+
train/No Pneumothorax/000112.dcm,No Pneumothorax,-0.15524542331695557
|
| 115 |
+
train/No Pneumothorax/000113.dcm,No Pneumothorax,-0.049727559089660645
|
| 116 |
+
train/No Pneumothorax/000114.dcm,No Pneumothorax,-0.13511207699775696
|
| 117 |
+
train/No Pneumothorax/000115.dcm,No Pneumothorax,-0.12422651052474976
|
| 118 |
+
train/No Pneumothorax/000116.dcm,No Pneumothorax,-0.18628454208374023
|
| 119 |
+
train/Pneumothorax/000117.dcm,Pneumothorax,-0.10177350044250488
|
| 120 |
+
train/No Pneumothorax/000118.dcm,No Pneumothorax,-0.21400070190429688
|
| 121 |
+
train/No Pneumothorax/000119.dcm,No Pneumothorax,-0.1408158540725708
|
| 122 |
+
train/Pneumothorax/000120.dcm,Pneumothorax,0.0021948814392089844
|
| 123 |
+
train/No Pneumothorax/000121.dcm,No Pneumothorax,-0.19585365056991577
|
| 124 |
+
train/No Pneumothorax/000122.dcm,No Pneumothorax,-0.21072477102279663
|
| 125 |
+
train/No Pneumothorax/000123.dcm,No Pneumothorax,-0.1782313585281372
|
| 126 |
+
train/No Pneumothorax/000124.dcm,No Pneumothorax,-0.11618930101394653
|
| 127 |
+
train/Pneumothorax/000125.dcm,Pneumothorax,-0.09667110443115234
|
| 128 |
+
train/Pneumothorax/000126.dcm,Pneumothorax,-0.1602795124053955
|
| 129 |
+
train/No Pneumothorax/000127.dcm,No Pneumothorax,-0.19403642416000366
|
| 130 |
+
train/No Pneumothorax/000128.dcm,No Pneumothorax,-0.09534329175949097
|
| 131 |
+
train/Pneumothorax/000129.dcm,Pneumothorax,-0.060146450996398926
|
| 132 |
+
train/Pneumothorax/000130.dcm,Pneumothorax,-0.07106363773345947
|
| 133 |
+
train/No Pneumothorax/000131.dcm,No Pneumothorax,-0.181549072265625
|
| 134 |
+
train/No Pneumothorax/000132.dcm,No Pneumothorax,-0.16202104091644287
|
| 135 |
+
train/No Pneumothorax/000133.dcm,No Pneumothorax,-0.1954769492149353
|
| 136 |
+
train/Pneumothorax/000134.dcm,Pneumothorax,-0.11729413270950317
|
| 137 |
+
train/No Pneumothorax/000135.dcm,No Pneumothorax,-0.18572622537612915
|
| 138 |
+
train/No Pneumothorax/000136.dcm,No Pneumothorax,-0.17218655347824097
|
| 139 |
+
train/No Pneumothorax/000137.dcm,No Pneumothorax,-0.19926214218139648
|
| 140 |
+
train/No Pneumothorax/000138.dcm,No Pneumothorax,-0.19930341839790344
|
| 141 |
+
train/Pneumothorax/000139.dcm,Pneumothorax,-0.1775619387626648
|
| 142 |
+
train/Pneumothorax/000140.dcm,Pneumothorax,-0.09699171781539917
|
| 143 |
+
train/No Pneumothorax/000141.dcm,No Pneumothorax,-0.1945207715034485
|
| 144 |
+
train/Pneumothorax/000142.dcm,Pneumothorax,-0.05407905578613281
|
| 145 |
+
train/No Pneumothorax/000143.dcm,No Pneumothorax,-0.14209693670272827
|
| 146 |
+
train/Pneumothorax/000144.dcm,Pneumothorax,-0.06982123851776123
|
| 147 |
+
train/Pneumothorax/000145.dcm,Pneumothorax,-0.13382339477539062
|
| 148 |
+
train/No Pneumothorax/000146.dcm,No Pneumothorax,-0.19120937585830688
|
| 149 |
+
train/No Pneumothorax/000147.dcm,No Pneumothorax,-0.15216165781021118
|
| 150 |
+
train/No Pneumothorax/000148.dcm,No Pneumothorax,-0.20141980051994324
|
| 151 |
+
train/No Pneumothorax/000149.dcm,No Pneumothorax,-0.20271122455596924
|
| 152 |
+
train/No Pneumothorax/000150.dcm,No Pneumothorax,-0.16529840230941772
|
| 153 |
+
train/Pneumothorax/000151.dcm,Pneumothorax,-0.15329903364181519
|
| 154 |
+
train/Pneumothorax/000152.dcm,Pneumothorax,-0.08588516712188721
|
| 155 |
+
train/Pneumothorax/000153.dcm,Pneumothorax,-0.15394580364227295
|
| 156 |
+
train/Pneumothorax/000154.dcm,Pneumothorax,-0.03996264934539795
|
| 157 |
+
train/Pneumothorax/000155.dcm,Pneumothorax,-0.13664811849594116
|
| 158 |
+
train/No Pneumothorax/000156.dcm,No Pneumothorax,-0.16998988389968872
|
| 159 |
+
train/Pneumothorax/000157.dcm,Pneumothorax,-0.09838944673538208
|
| 160 |
+
train/Pneumothorax/000158.dcm,Pneumothorax,-0.1137932538986206
|
| 161 |
+
train/No Pneumothorax/000159.dcm,No Pneumothorax,-0.16903436183929443
|
| 162 |
+
train/No Pneumothorax/000160.dcm,No Pneumothorax,-0.19102925062179565
|
| 163 |
+
train/Pneumothorax/000161.dcm,Pneumothorax,-0.13560515642166138
|
| 164 |
+
train/No Pneumothorax/000162.dcm,No Pneumothorax,-0.21745437383651733
|
| 165 |
+
train/No Pneumothorax/000163.dcm,No Pneumothorax,-0.16950178146362305
|
| 166 |
+
train/No Pneumothorax/000164.dcm,No Pneumothorax,-0.06001162528991699
|
| 167 |
+
train/No Pneumothorax/000165.dcm,No Pneumothorax,-0.15253078937530518
|
| 168 |
+
train/No Pneumothorax/000166.dcm,No Pneumothorax,-0.2072521150112152
|
| 169 |
+
train/No Pneumothorax/000167.dcm,No Pneumothorax,-0.1807345747947693
|
| 170 |
+
train/No Pneumothorax/000168.dcm,No Pneumothorax,-0.1796715259552002
|
| 171 |
+
train/No Pneumothorax/000169.dcm,No Pneumothorax,-0.12789452075958252
|
| 172 |
+
train/Pneumothorax/000170.dcm,Pneumothorax,-0.10478848218917847
|
| 173 |
+
train/No Pneumothorax/000171.dcm,No Pneumothorax,-0.15731322765350342
|
| 174 |
+
train/No Pneumothorax/000172.dcm,No Pneumothorax,-0.12667322158813477
|
| 175 |
+
train/Pneumothorax/000173.dcm,Pneumothorax,-0.16081935167312622
|
| 176 |
+
train/No Pneumothorax/000174.dcm,No Pneumothorax,-0.14709264039993286
|
| 177 |
+
train/No Pneumothorax/000175.dcm,No Pneumothorax,-0.1776459813117981
|
| 178 |
+
train/No Pneumothorax/000176.dcm,No Pneumothorax,-0.16818833351135254
|
| 179 |
+
train/No Pneumothorax/000177.dcm,No Pneumothorax,-0.16316360235214233
|
| 180 |
+
train/No Pneumothorax/000178.dcm,No Pneumothorax,-0.1608293354511261
|
| 181 |
+
train/No Pneumothorax/000179.dcm,No Pneumothorax,-0.1174200177192688
|
| 182 |
+
train/No Pneumothorax/000180.dcm,No Pneumothorax,-0.15724217891693115
|
| 183 |
+
train/Pneumothorax/000181.dcm,Pneumothorax,-0.06315004825592041
|
| 184 |
+
train/No Pneumothorax/000182.dcm,No Pneumothorax,-0.18269914388656616
|
| 185 |
+
train/No Pneumothorax/000183.dcm,No Pneumothorax,-0.1433737874031067
|
| 186 |
+
train/No Pneumothorax/000184.dcm,No Pneumothorax,-0.19049185514450073
|
| 187 |
+
train/Pneumothorax/000185.dcm,Pneumothorax,-0.06804823875427246
|
| 188 |
+
train/No Pneumothorax/000186.dcm,No Pneumothorax,-0.20442822575569153
|
| 189 |
+
train/Pneumothorax/000187.dcm,Pneumothorax,-0.1479809284210205
|
| 190 |
+
train/No Pneumothorax/000188.dcm,No Pneumothorax,-0.1297626495361328
|
| 191 |
+
train/Pneumothorax/000189.dcm,Pneumothorax,-0.14443817734718323
|
| 192 |
+
train/No Pneumothorax/000190.dcm,No Pneumothorax,-0.20211485028266907
|
| 193 |
+
train/Pneumothorax/000191.dcm,Pneumothorax,-0.10677635669708252
|
| 194 |
+
train/No Pneumothorax/000192.dcm,No Pneumothorax,-0.15862226486206055
|
| 195 |
+
train/No Pneumothorax/000193.dcm,No Pneumothorax,-0.14113175868988037
|
| 196 |
+
train/No Pneumothorax/000194.dcm,No Pneumothorax,-0.22007161378860474
|
| 197 |
+
train/No Pneumothorax/000195.dcm,No Pneumothorax,-0.10471892356872559
|
| 198 |
+
train/No Pneumothorax/000196.dcm,No Pneumothorax,-0.20787471532821655
|
| 199 |
+
train/No Pneumothorax/000197.dcm,No Pneumothorax,-0.16002091765403748
|
| 200 |
+
train/No Pneumothorax/000198.dcm,No Pneumothorax,-0.17423555254936218
|
| 201 |
+
train/Pneumothorax/000199.dcm,Pneumothorax,-0.0016582608222961426
|
| 202 |
+
train/Pneumothorax/000200.dcm,Pneumothorax,-0.10900384187698364
|
| 203 |
+
train/Pneumothorax/000201.dcm,Pneumothorax,0.029024243354797363
|
| 204 |
+
train/No Pneumothorax/000202.dcm,No Pneumothorax,-0.10805076360702515
|
| 205 |
+
train/No Pneumothorax/000203.dcm,No Pneumothorax,-0.1146092414855957
|
| 206 |
+
train/No Pneumothorax/000204.dcm,No Pneumothorax,-0.19227838516235352
|
| 207 |
+
train/No Pneumothorax/000205.dcm,No Pneumothorax,-0.19742238521575928
|
| 208 |
+
train/No Pneumothorax/000206.dcm,No Pneumothorax,-0.23522859811782837
|
| 209 |
+
train/No Pneumothorax/000207.dcm,No Pneumothorax,-0.17371898889541626
|
| 210 |
+
train/No Pneumothorax/000208.dcm,No Pneumothorax,-0.15263259410858154
|
| 211 |
+
train/No Pneumothorax/000209.dcm,No Pneumothorax,-0.15728116035461426
|
| 212 |
+
train/No Pneumothorax/000210.dcm,No Pneumothorax,-0.13311928510665894
|
| 213 |
+
train/No Pneumothorax/000211.dcm,No Pneumothorax,-0.1066751480102539
|
| 214 |
+
train/No Pneumothorax/000212.dcm,No Pneumothorax,-0.1832524538040161
|
| 215 |
+
train/Pneumothorax/000213.dcm,Pneumothorax,-0.09565180540084839
|
| 216 |
+
train/Pneumothorax/000214.dcm,Pneumothorax,-0.06287646293640137
|
| 217 |
+
train/No Pneumothorax/000215.dcm,No Pneumothorax,-0.07406270503997803
|
| 218 |
+
train/No Pneumothorax/000216.dcm,No Pneumothorax,-0.08096122741699219
|
| 219 |
+
train/Pneumothorax/000217.dcm,Pneumothorax,-0.05490368604660034
|
| 220 |
+
train/No Pneumothorax/000218.dcm,No Pneumothorax,-0.05416452884674072
|
| 221 |
+
train/Pneumothorax/000219.dcm,Pneumothorax,-0.01159369945526123
|
| 222 |
+
train/No Pneumothorax/000220.dcm,No Pneumothorax,-0.1184004545211792
|
| 223 |
+
train/No Pneumothorax/000221.dcm,No Pneumothorax,-0.2137947976589203
|
| 224 |
+
train/Pneumothorax/000222.dcm,Pneumothorax,-0.10213685035705566
|
| 225 |
+
train/Pneumothorax/000223.dcm,Pneumothorax,-0.12993067502975464
|
| 226 |
+
train/No Pneumothorax/000224.dcm,No Pneumothorax,-0.1637454628944397
|
| 227 |
+
train/No Pneumothorax/000225.dcm,No Pneumothorax,-0.1220596432685852
|
| 228 |
+
train/No Pneumothorax/000226.dcm,No Pneumothorax,-0.1765921711921692
|
| 229 |
+
train/Pneumothorax/000227.dcm,Pneumothorax,-0.05948609113693237
|
| 230 |
+
train/No Pneumothorax/000228.dcm,No Pneumothorax,-0.16500937938690186
|
| 231 |
+
train/No Pneumothorax/000229.dcm,No Pneumothorax,-0.2087046504020691
|
| 232 |
+
train/No Pneumothorax/000230.dcm,No Pneumothorax,-0.10890644788742065
|
| 233 |
+
train/No Pneumothorax/000231.dcm,No Pneumothorax,-0.21980196237564087
|
| 234 |
+
train/Pneumothorax/000232.dcm,Pneumothorax,-0.042661070823669434
|
| 235 |
+
train/Pneumothorax/000233.dcm,Pneumothorax,-0.07404029369354248
|
| 236 |
+
train/No Pneumothorax/000234.dcm,No Pneumothorax,-0.19613447785377502
|
| 237 |
+
train/No Pneumothorax/000235.dcm,No Pneumothorax,-0.16667985916137695
|
| 238 |
+
train/Pneumothorax/000236.dcm,Pneumothorax,-0.07997268438339233
|
| 239 |
+
train/No Pneumothorax/000237.dcm,No Pneumothorax,-0.17295250296592712
|
| 240 |
+
train/Pneumothorax/000238.dcm,Pneumothorax,-0.02200949192047119
|
| 241 |
+
train/No Pneumothorax/000239.dcm,No Pneumothorax,-0.1404871940612793
|
| 242 |
+
train/Pneumothorax/000240.dcm,Pneumothorax,-0.045901477336883545
|
| 243 |
+
train/No Pneumothorax/000241.dcm,No Pneumothorax,-0.13813674449920654
|
| 244 |
+
train/No Pneumothorax/000242.dcm,No Pneumothorax,-0.1337980031967163
|
| 245 |
+
train/No Pneumothorax/000243.dcm,No Pneumothorax,-0.09456092119216919
|
| 246 |
+
train/No Pneumothorax/000244.dcm,No Pneumothorax,-0.12863361835479736
|
| 247 |
+
train/No Pneumothorax/000245.dcm,No Pneumothorax,-0.23645243048667908
|
| 248 |
+
train/No Pneumothorax/000246.dcm,No Pneumothorax,-0.17193585634231567
|
| 249 |
+
train/Pneumothorax/000247.dcm,Pneumothorax,-0.0540735125541687
|
| 250 |
+
train/No Pneumothorax/000248.dcm,No Pneumothorax,-0.18047136068344116
|
| 251 |
+
train/No Pneumothorax/000249.dcm,No Pneumothorax,-0.1623515486717224
|
results/kaggle_roc_curve.png
ADDED
|
Git LFS Details
|
results/roc_PNEUMOTHORAX.png
ADDED
|
Git LFS Details
|
src/app.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
# Configure logging
|
| 9 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
# Suppress TensorFlow logging
|
| 13 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
| 14 |
+
try:
|
| 15 |
+
import absl.logging
|
| 16 |
+
absl.logging.set_verbosity(absl.logging.ERROR)
|
| 17 |
+
except ImportError:
|
| 18 |
+
pass
|
| 19 |
+
logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
| 20 |
+
|
| 21 |
+
from model import RawImageModel, PrecomputedModel
|
| 22 |
+
|
| 23 |
+
# Global Model Instances
|
| 24 |
+
raw_model = None
|
| 25 |
+
precomputed_model = None
|
| 26 |
+
pos_emb = None
|
| 27 |
+
neg_emb = None
|
| 28 |
+
|
| 29 |
+
# Decision threshold (Youden's J, calibrated on the local Kaggle Pneumothorax dataset); scores >= THRESHOLD are reported as pneumothorax
|
| 30 |
+
THRESHOLD = -0.1173
|
| 31 |
+
|
| 32 |
+
def load_models():
    """Load both models and cache the diagnosis text embeddings in module globals.

    Populates the module-level ``precomputed_model``, ``raw_model``,
    ``pos_emb`` and ``neg_emb``. Returns immediately when ``raw_model`` is
    already set.

    Everything is built in local variables first and published to the globals
    only after every step has succeeded. Otherwise a failure while fetching
    the text embeddings would leave ``raw_model`` set with ``pos_emb`` /
    ``neg_emb`` still ``None``, and every later call would skip loading.

    Raises:
        Exception: Whatever the model constructors or the embedding lookup
            raise. The error is logged with its traceback and re-raised
            unchanged.
    """
    global raw_model, precomputed_model, pos_emb, neg_emb
    if raw_model is not None:
        return

    logger.info("Loading models...")
    try:
        loaded_precomputed = PrecomputedModel()
        loaded_raw = RawImageModel()

        # Pre-fetch the text embeddings for the two zero-shot prompts.
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        loaded_pos, loaded_neg = loaded_precomputed.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.exception("Failed to load models: %s", e)
        # Bare raise keeps the original traceback intact.
        raise

    precomputed_model = loaded_precomputed
    pos_emb, neg_emb = loaded_pos, loaded_neg
    # Assigned last: a non-None raw_model is what marks loading as complete.
    raw_model = loaded_raw
    logger.info("Models loaded.")
|
| 48 |
+
|
| 49 |
+
def predict(image):
    """Classify one uploaded image and return the values shown in the UI.

    Args:
        image: the uploaded image object, or ``None`` when nothing was
            uploaded. It only needs a ``save(path)`` method (the UI passes a
            PIL image).

    Returns:
        A ``(prediction, score, details)`` tuple: the label string, the raw
        zero-shot score as a float, and a human-readable details string.
        On failure the tuple is ``("Error", 0.0, <error message>)``.
    """
    import tempfile  # function-scope import keeps this block self-contained

    if image is None:
        return "No image uploaded.", 0.0, "Please upload an image."

    temp_path = None
    try:
        # The model ingests a file path, so the upload is written to disk.
        # A unique temp file (instead of one fixed filename in the working
        # directory) keeps concurrent requests from overwriting each other.
        fd, temp_path = tempfile.mkstemp(suffix=".png", prefix="gradio_upload_")
        os.close(fd)
        image.save(temp_path)

        # Run Inference
        img_emb = raw_model.compute_embeddings(temp_path)
        score = float(PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb))

        # Binary Classification: the raw score is compared against the
        # calibrated threshold; no probability mapping is applied.
        prediction = "PNEUMOTHORAX" if score >= THRESHOLD else "NORMAL / NO PNEUMOTHORAX"

        return prediction, score, f"Raw Score: {score:.4f} (Threshold: {THRESHOLD})"

    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        return "Error", 0.0, str(e)

    finally:
        # Always remove the temp file, whether inference succeeded or not.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass  # best-effort cleanup; nothing useful to do on failure
|
| 75 |
+
|
| 76 |
+
# Populate the module-level models and text embeddings before the UI is built
load_models()

# Page layout: upload controls and result widgets side by side, with the
# benchmark plots in tabs underneath.
with gr.Blocks(title="Chest X-Ray Zero-Shot Classifier") as demo:
    gr.Markdown("# 🩻 Zero-Shot Chest X-Ray Classification")
    gr.Markdown("Detect **Pneumothorax** from raw X-ray images using the `google/cxr-foundation` model.")

    with gr.Row():
        # Input side
        with gr.Column():
            gr.Markdown("### 1. Upload X-Ray")
            input_image = gr.Image(
                type="pil",
                label="Upload Image (PNG/JPG/DICOM converted)",
            )
            predict_btn = gr.Button("Analyze Image", variant="primary")

        # Output side
        with gr.Column():
            gr.Markdown("### 2. Results")
            output_label = gr.Label(label="Prediction")
            output_score = gr.Number(label="Zero-Shot Score")
            output_msg = gr.Textbox(label="Details")

    gr.Markdown("---")
    gr.Markdown("### Performance Context")
    gr.Markdown("This model uses a **zero-shot** approach. The threshold was calibrated using a local Kaggle dataset.")

    with gr.Tabs():
        with gr.TabItem("Local Kaggle Benchmark"):
            gr.Image("results/kaggle_roc_curve.png", label="local ROC Curve")
            gr.Markdown("**AUC: 0.88** on 250 local samples.")
        with gr.TabItem("Google Benchmark"):
            gr.Image("results/roc_PNEUMOTHORAX.png", label="Reference ROC")

    # predict() returns (prediction, score, details), in this widget order
    result_widgets = [output_label, output_score, output_msg]
    predict_btn.click(predict, inputs=input_image, outputs=result_widgets)

if __name__ == "__main__":
    demo.launch(share=True)
|
src/calculate_threshold.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.metrics import roc_curve
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# Configure logging
|
| 8 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
def calculate_optimal_threshold(results_path):
    """
    Calculates optimal threshold using Youden's J statistic.

    Reads a predictions CSV with ``true_label`` and ``pneumothorax_score``
    columns, drops rows without a score, and treats ``true_label ==
    'Pneumothorax'`` as the positive class.

    Args:
        results_path: path to the predictions CSV.

    Returns:
        The score threshold that maximises J = TPR - FPR, or ``None`` when
        the file is missing, has no scored rows, does not contain both
        classes, or cannot be processed.
    """
    if not os.path.exists(results_path):
        logger.error(f"Results file not found: {results_path}")
        return None

    try:
        df = pd.read_csv(results_path)
        logger.info(f"Loaded {len(df)} predictions from {results_path}")

        df = df.dropna(subset=['pneumothorax_score'])
        if len(df) == 0:
            logger.error("No valid predictions found.")
            return None

        # Binary Labels
        y_true = (df['true_label'] == 'Pneumothorax').astype(int)
        y_scores = df['pneumothorax_score']

        # With a single class, TPR or FPR is undefined (NaN) and the argmax
        # below would pick an arbitrary, meaningless threshold.
        if y_true.nunique() < 2:
            logger.error("Both positive and negative samples are required to compute a threshold.")
            return None

        fpr, tpr, thresholds = roc_curve(y_true, y_scores)

        # Youden's J = Sensitivity + Specificity - 1
        # Sensitivity = TPR
        # Specificity = 1 - FPR
        # J = TPR + (1 - FPR) - 1 = TPR - FPR
        j_scores = tpr - fpr

        # thresholds[0] is roc_curve's sentinel for "predict nothing positive"
        # (not a real score), so it is excluded from the search.
        best_idx = int(np.argmax(j_scores[1:])) + 1
        best_threshold = thresholds[best_idx]

        logger.info(f"Optimal Threshold (Youden's J): {best_threshold:.4f}")
        logger.info(f"Sensitivity: {tpr[best_idx]:.4f}")
        logger.info(f"Specificity: {1 - fpr[best_idx]:.4f}")

        return best_threshold

    except Exception as e:
        # Best-effort helper: report the problem and signal failure with None.
        logger.error(f"Failed to calculate threshold: {e}")
        return None
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
|
| 53 |
+
results_file = "results/kaggle_predictions.csv"
|
| 54 |
+
calculate_optimal_threshold(results_file)
|
src/create_dummy_image.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

import png
import numpy as np

# Generate a 512x512 random-noise, 16-bit greyscale image to simulate an X-ray
width = 512
height = 512
# randint's upper bound is exclusive, so 65536 is needed to cover the full
# uint16 range (0..65535).
img = np.random.randint(0, 65536, (height, width), dtype=np.uint16)

# Make sure the target directory exists so the script works on a fresh checkout
output_path = 'data/test_xray.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, 'wb') as f:
    writer = png.Writer(width=width, height=height, greyscale=True, bitdepth=16)
    writer.write(f, img.tolist())

print("Created data/test_xray.png")
|
src/dicom_utils.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pydicom
|
| 2 |
+
import numpy as np
|
| 3 |
+
import logging
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
def read_dicom_image(file_path):
    """
    Reads a DICOM file and returns its pixels as a uint8 NumPy array.

    The pixel data is rescaled with RescaleSlope / RescaleIntercept when they
    are present, min-max normalised to the 0-255 range, and inverted when the
    photometric interpretation is MONOCHROME1.

    Args:
        file_path: path of the DICOM file to read.

    Returns:
        A ``numpy.uint8`` array with values in 0-255 (all zeros for a
        constant image before any inversion).

    Raises:
        ValueError: if the dataset has no pixel data.
        Exception: any other read/decode error is logged and re-raised.
    """
    try:
        ds = pydicom.dcmread(file_path)

        # Handle pixel data
        if 'PixelData' not in ds:
            raise ValueError(f"No pixel data found in {file_path}")

        image_data = ds.pixel_array.astype(float)

        # Handle RescaleSlope and RescaleIntercept if present. A tag that is
        # missing, or present with an empty value (None), falls back to the
        # identity transform instead of breaking the arithmetic below.
        slope = getattr(ds, 'RescaleSlope', None)
        intercept = getattr(ds, 'RescaleIntercept', None)
        slope = 1.0 if slope is None else float(slope)
        intercept = 0.0 if intercept is None else float(intercept)
        image_data = image_data * slope + intercept

        # Normalize to 0-255 range for consistency with standard image processing
        # Note: This discards absolute physical values but preserves structure for the model
        image_min = np.min(image_data)
        image_max = np.max(image_data)

        if image_max != image_min:
            image_data = (image_data - image_min) / (image_max - image_min) * 255.0
        else:
            # Constant image: avoid dividing by zero
            image_data = np.zeros_like(image_data)

        image_data = image_data.astype(np.uint8)

        # Handle photometric interpretation (invert if needed)
        # MONOCHROME1 typically means 0 is white, 255 is black (inverse of standard X-ray)
        # We generally want air (black) to be low, bone (white) to be high
        # getattr: a file without the tag is treated as "no inversion needed"
        # rather than failing with AttributeError.
        if getattr(ds, 'PhotometricInterpretation', None) == "MONOCHROME1":
            image_data = 255 - image_data

        return image_data

    except Exception as e:
        logger.error(f"Error reading DICOM {file_path}: {e}")
        raise
|
src/download.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import wget
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
# Configure logging
|
| 6 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
# Constants
|
| 10 |
+
DATA_DIR = "data"
|
| 11 |
+
URLS = {
|
| 12 |
+
"precomputed_image_embeddings.npz": "https://storage.googleapis.com/healthai-us/encoded-data/nih/radiology/cxr/precomputed_image_embeddings.npz",
|
| 13 |
+
"precomputed_text_embeddings.npz": "https://storage.googleapis.com/healthai-us/encoded-data/nih/radiology/cxr/precomputed_text_embeddings.npz",
|
| 14 |
+
"cxr14_subset_labels.csv": "https://storage.googleapis.com/healthai-us/encoded-data/nih/radiology/cxr/cxr14_subset_labels.csv",
|
| 15 |
+
"sample.png": "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
def download_file(url, filename, output_dir):
    """Fetch ``url`` into ``output_dir/filename`` unless that file is already present.

    A failed download is logged and otherwise ignored; nothing is returned.
    """
    filepath = os.path.join(output_dir, filename)

    already_present = os.path.exists(filepath)
    if not already_present:
        logger.info(f"Downloading {filename} from {url}...")
        try:
            wget.download(url, out=filepath)
        except Exception as e:
            logger.error(f"Failed to download {filename}: {e}")
        else:
            # wget leaves the cursor on its progress-bar line; end that line
            print()
            logger.info(f"Downloaded {filename}")
        return

    logger.info(f"File already exists: {filepath}")
|
| 32 |
+
|
| 33 |
+
def main():
    """Create the data directory when needed, then fetch every entry of URLS into it."""
    directory_missing = not os.path.exists(DATA_DIR)
    if directory_missing:
        os.makedirs(DATA_DIR)
        logger.info(f"Created directory: {DATA_DIR}")

    # URLS maps the local file name to its remote location
    for target_name in URLS:
        download_file(URLS[target_name], target_name, DATA_DIR)
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
main()
|
src/evaluate.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from sklearn.metrics import roc_curve, auc
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
def evaluate_predictions(scores, true_labels, diagnosis, output_dir="results"):
    """Calculates AUC and generates ROC plot.

    Args:
        scores: predicted scores, one per sample.
        true_labels: ground-truth binary labels aligned with ``scores``.
        diagnosis: name used in the log output, the plot title and the
            output file name (``roc_<diagnosis>.png``).
        output_dir: directory the plot is written to; created when missing.

    Returns:
        The area under the ROC curve.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(output_dir, exist_ok=True)

    fpr, tpr, _ = roc_curve(true_labels, scores)
    roc_auc = auc(fpr, tpr)

    logger.info(f"Diagnosis: {diagnosis}")
    logger.info(f"AUC: {roc_auc:.4f}")

    # Plot ROC curve
    fig = plt.figure()
    try:
        lw = 2
        plt.plot(
            fpr,
            tpr,
            color="darkorange",
            lw=lw,
            label="ROC curve (area = %0.2f)" % roc_auc,
        )
        # Diagonal reference line for a random classifier
        plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title(f"ROC for {diagnosis}")
        plt.legend(loc="lower right")

        plot_path = os.path.join(output_dir, f"roc_{diagnosis}.png")
        plt.savefig(plot_path)
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

    logger.info(f"ROC plot saved to {plot_path}")

    return roc_auc
|
src/evaluate_kaggle.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import logging
|
| 5 |
+
import argparse
|
| 6 |
+
import numpy as np
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
# Configure logging
|
| 10 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# Suppress TensorFlow logging
|
| 14 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
| 15 |
+
try:
|
| 16 |
+
import absl.logging
|
| 17 |
+
absl.logging.set_verbosity(absl.logging.ERROR)
|
| 18 |
+
except ImportError:
|
| 19 |
+
pass
|
| 20 |
+
import logging
|
| 21 |
+
logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
| 22 |
+
|
| 23 |
+
from model import RawImageModel, PrecomputedModel
|
| 24 |
+
from dicom_utils import read_dicom_image
|
| 25 |
+
from PIL import Image
|
| 26 |
+
|
| 27 |
+
def main():
    """Score every DICOM listed in a labels CSV and write the predictions CSV.

    For each row the DICOM is read, written out as a temporary PNG (the raw
    image model takes a file path), embedded, and scored against the
    'small pneumothorax' / 'no pneumothorax' text embeddings. Rows that cannot
    be processed are kept in the output with an ``error`` column. Setup
    failures are logged and end the run early.
    """
    import tempfile  # function-scope import; only this entry point needs it

    parser = argparse.ArgumentParser(description="Evaluate on Kaggle DICOM Dataset")
    parser.add_argument("--csv", default="data/kaggle/labels.csv", help="Path to labels CSV")
    parser.add_argument("--data-dir", default="data/kaggle", help="Root directory for images if relative paths in CSV")
    parser.add_argument("--output", default="results/kaggle_predictions.csv", help="Output predictions file")
    args = parser.parse_args()

    # Create output directory. dirname() is '' for a bare file name and
    # os.makedirs('') raises, so only create a directory when there is one.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Load dataset
    try:
        df = pd.read_csv(args.csv)
        logger.info(f"Loaded {len(df)} records from {args.csv}")
    except Exception as e:
        logger.error(f"Failed to load CSV: {e}")
        return

    # Check for file column
    file_col = 'file' if 'file' in df.columns else 'dicom_file'  # Adapt to potential column names
    if file_col not in df.columns:
        logger.error(f"Missing file column in CSV. Found: {df.columns}")
        return

    # Initialize Models
    try:
        # We need PrecomputedModel for text embeddings (labels)
        precomputed_model = PrecomputedModel()

        # We need RawImageModel for the images
        raw_model = RawImageModel()
        logger.info("Models loaded successfully.")
    except Exception as e:
        logger.critical(f"Failed to initialize models: {e}")
        return

    # Get text embeddings for diagnosis
    diagnosis = 'PNEUMOTHORAX'
    try:
        # Hardcoded prompts matching main.py
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.critical(f"Failed to get text embeddings: {e}")
        return

    predictions = []

    # Iterate and predict
    print(f"Running inference for {diagnosis} on {len(df)} images...")

    # One unique scratch file for the whole run, reused for every image. A
    # unique name avoids clashes with other runs in the same directory.
    fd, temp_path = tempfile.mkstemp(suffix=".png", prefix="temp_inference_")
    os.close(fd)

    try:
        for _, row in tqdm(df.iterrows(), total=len(df)):
            file_path = row[file_col]
            # Construct full path
            full_path = os.path.join(args.data_dir, file_path) if not os.path.isabs(file_path) else file_path

            # Check if file exists
            if not os.path.exists(full_path):
                logger.warning(f"File not found: {full_path}")
                predictions.append({
                    'file': file_path,
                    'true_label': None,
                    'pneumothorax_score': None,
                    'error': 'File not found'
                })
                continue

            true_label = row.get('label', row.get('PNEUMOTHORAX', 'Unknown'))

            try:
                # 1. Read DICOM
                image_array = read_dicom_image(full_path)

                # 2. Save as temp PNG (Required by RawImageModel/TF pipeline currently)
                Image.fromarray(image_array).save(temp_path)

                # 3. Compute Image Embedding
                img_emb = raw_model.compute_embeddings(temp_path)

                # 4. Compute Zero-Shot Score
                score = PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb)

                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': float(score)
                })

            except Exception as e:
                # Keep going, but make the failure visible in the log as well
                # as in the output CSV.
                logger.warning(f"Failed to process {file_path}: {e}")
                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': None,
                    'error': str(e)
                })

            # Incremental Save every 10 items
            if len(predictions) % 10 == 0:
                pd.DataFrame(predictions).to_csv(args.output, index=False)

        # Final Save
        results_df = pd.DataFrame(predictions)
        results_df.to_csv(args.output, index=False)
        logger.info(f"Predictions saved to {args.output}")

    finally:
        # Cleanup runs even when the loop is interrupted or fails
        if os.path.exists(temp_path):
            os.remove(temp_path)
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
|
| 142 |
+
main()
|
src/main.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import logging
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
# Suppress TensorFlow and system warnings
|
| 7 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # FATAL
|
| 8 |
+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
| 9 |
+
|
| 10 |
+
import warnings
|
| 11 |
+
warnings.filterwarnings('ignore')
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
|
| 16 |
+
# Configure logging first
|
| 17 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Suppress absl logging from TensorFlow
|
| 21 |
+
try:
|
| 22 |
+
import absl.logging
|
| 23 |
+
absl.logging.set_verbosity(absl.logging.ERROR)
|
| 24 |
+
except ImportError:
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
# Suppress TensorFlow Python logging
|
| 28 |
+
logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
| 29 |
+
|
| 30 |
+
from model import PrecomputedModel, RawImageModel
|
| 31 |
+
from evaluate import evaluate_predictions
|
| 32 |
+
|
| 33 |
+
DIAGNOSIS_PROMPTS = {
|
| 34 |
+
'AIRSPACE_OPACITY': ('Airspace Opacity', 'no evidence of airspace disease'),
|
| 35 |
+
'PNEUMOTHORAX': ('small pneumothorax', 'no pneumothorax'),
|
| 36 |
+
'EFFUSION': ('large pleural effusion', 'no pleural effusion'),
|
| 37 |
+
'PULMONARY_EDEMA': ('moderate pulmonary edema', 'no pulmonary edema'),
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
def main():
    """Command-line entry point for zero-shot chest X-ray classification.

    With ``--raw-image`` the given image file is embedded and a single
    zero-shot score is printed. Without it, every precomputed image embedding
    that has a 0/1 label for the chosen diagnosis is scored and the scores are
    passed to ``evaluate_predictions``. The process exits with status 1 when
    the raw image cannot be processed or the evaluation data is unavailable.
    """
    parser = argparse.ArgumentParser(description="Zero-Shot Chest X-Ray Classification")
    parser.add_argument("--diagnosis", type=str, choices=DIAGNOSIS_PROMPTS.keys(), required=True, help="Diagnosis to evaluate")
    parser.add_argument("--data-dir", type=str, default="data", help="Path to data directory")
    parser.add_argument("--raw-image", type=str, help="Path to a raw image file for inference (optional)")
    args = parser.parse_args()

    # Get prompts
    pos_txt, neg_txt = DIAGNOSIS_PROMPTS[args.diagnosis]
    logger.info(f"Diagnosis: {args.diagnosis}")
    logger.info(f"Positive query: '{pos_txt}'")
    logger.info(f"Negative query: '{neg_txt}'")

    # Load precomputed model for text embeddings (and image embeddings if no raw image)
    precomputed_model = PrecomputedModel(data_dir=args.data_dir)
    pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)

    if args.raw_image:
        # Raw Image Inference Mode
        logger.info(f"Running inference on raw image: {args.raw_image}")
        raw_model = RawImageModel()
        try:
            image_emb = raw_model.compute_embeddings(args.raw_image)
            score = PrecomputedModel.zero_shot(image_emb, pos_emb, neg_emb)
            logger.info(f"Zero-shot score for {args.raw_image}: {score:.4f}")

            # A single image gives no meaningful AUC, so only the score is
            # reported in this mode.
            print(f"Score for {args.diagnosis}: {score}")

        except Exception as e:
            logger.error(f"Failed to process raw image: {e}")
            sys.exit(1)

    else:
        # Precomputed Embeddings Evaluation Mode (Full Dataset)
        logger.info("Running evaluation on full precomputed dataset...")

        # PrecomputedModel leaves these as None when the optional files are
        # missing; fail with a clear message instead of a TypeError below.
        labels_df = precomputed_model.labels
        if labels_df is None or precomputed_model.image_embeddings is None:
            logger.error("Precomputed image embeddings and labels are required for evaluation. Check --data-dir.")
            sys.exit(1)

        # Filter labels for the target diagnosis (0 or 1)
        target_df = labels_df[labels_df[args.diagnosis].isin([0, 1])][['image_id', args.diagnosis]].copy()

        image_ids = target_df['image_id'].tolist()
        true_labels = target_df[args.diagnosis].tolist()

        # Compute scores
        valid_ids, scores = precomputed_model.compute_scores(image_ids, pos_emb, neg_emb)

        if not scores:
            logger.error("No valid scores computed. Check embedding match.")
            sys.exit(1)

        # Keep only the labels whose image had an embedding. A set makes each
        # membership test O(1) instead of scanning the list every time.
        scored_ids = set(valid_ids)
        final_labels = [
            label for img_id, label in zip(image_ids, true_labels) if img_id in scored_ids
        ]

        # Evaluate
        evaluate_predictions(scores, final_labels, args.diagnosis)
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
|
| 106 |
+
main()
|
src/model.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from huggingface_hub import snapshot_download
|
| 6 |
+
|
| 7 |
+
# Configure logging
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class PrecomputedModel:
    """Zero-shot scoring on top of precomputed embeddings stored on disk.

    Text embeddings are mandatory. Image embeddings and labels are optional
    and stay ``None`` when their files are missing from ``data_dir``.
    """

    def __init__(self, data_dir="data"):
        """Load the embeddings and labels found in ``data_dir``."""
        self.data_dir = data_dir
        # dict: key -> array, or None when the file is absent
        self.image_embeddings = None
        # dict: text query -> array
        self.text_embeddings = None
        # pandas DataFrame, or None when the file is absent
        self.labels = None
        self._load_data()

    def _load_data(self):
        """Loads precomputed embeddings and labels.

        Raises:
            FileNotFoundError: if the text-embeddings file is missing.
        """
        img_emb_path = os.path.join(self.data_dir, "precomputed_image_embeddings.npz")
        txt_emb_path = os.path.join(self.data_dir, "precomputed_text_embeddings.npz")
        labels_path = os.path.join(self.data_dir, "cxr14_subset_labels.csv")

        # Text embeddings are strictly required for Zero-Shot
        if not os.path.exists(txt_emb_path):
            raise FileNotFoundError(f"Missing required text embeddings: {txt_emb_path}")

        logger.info("Loading precomputed text embeddings...")
        with np.load(txt_emb_path) as data:
            # Copy into a dict so the arrays outlive the closed archive
            self.text_embeddings = {key: data[key] for key in data}

        # Image embeddings (Optional, only for benchmarking)
        if os.path.exists(img_emb_path):
            logger.info("Loading precomputed image embeddings...")
            with np.load(img_emb_path) as data:
                self.image_embeddings = {key: data[key] for key in data}
        else:
            logger.warning("Precomputed image embeddings not found. Benchmarking features will be disabled.")

        # Labels (Optional)
        if os.path.exists(labels_path):
            logger.info("Loading labels...")
            self.labels = pd.read_csv(labels_path)
        else:
            logger.warning("Labels file not found.")

    def get_diagnosis_embeddings(self, pos_txt, neg_txt):
        """Retrieves embeddings for positive and negative text queries.

        Raises:
            ValueError: if either query has no precomputed embedding.
        """
        if pos_txt not in self.text_embeddings:
            raise ValueError(f"Positive query '{pos_txt}' not found in precomputed embeddings.")
        if neg_txt not in self.text_embeddings:
            raise ValueError(f"Negative query '{neg_txt}' not found in precomputed embeddings.")

        return self.text_embeddings[pos_txt], self.text_embeddings[neg_txt]

    def compute_scores(self, image_ids, pos_emb, neg_emb):
        """Computes zero-shot scores for a list of image IDs.

        IDs without a precomputed image embedding are skipped.

        Returns:
            ``(valid_ids, scores)``: the IDs that were scored, in input order,
            and their scores in the same order.

        Raises:
            RuntimeError: if no image embeddings were loaded.
        """
        if self.image_embeddings is None:
            # Without this guard the membership test below fails with an
            # opaque "argument of type 'NoneType' is not iterable".
            raise RuntimeError(
                "Precomputed image embeddings are not loaded; cannot compute scores."
            )

        scores = []
        valid_ids = []

        for img_id in image_ids:
            if img_id not in self.image_embeddings:
                continue

            img_emb = self.image_embeddings[img_id]
            scores.append(self.zero_shot(img_emb, pos_emb, neg_emb))
            valid_ids.append(img_id)

        return valid_ids, scores

    @staticmethod
    def compute_image_text_similarity(image_emb, txt_emb):
        """Return the highest cosine similarity between any image token and the text.

        ``image_emb`` is reshaped to 32 rows of 128 values, so it must hold
        exactly 32 * 128 elements; ``txt_emb`` is flattened to one vector.
        """
        image_tokens = np.reshape(image_emb, (32, 128))
        txt_vec = np.ravel(txt_emb)

        # Cosine similarity of every token row against the text vector,
        # computed in one vectorised step instead of a Python loop.
        token_norms = np.linalg.norm(image_tokens, axis=1)
        similarities = (image_tokens @ txt_vec) / (token_norms * np.linalg.norm(txt_vec))

        return np.max(similarities)

    @classmethod
    def zero_shot(cls, image_emb, pos_txt_emb, neg_txt_emb):
        """Computes the zero-shot score (pos_sim - neg_sim)."""
        pos_cosine = cls.compute_image_text_similarity(image_emb, pos_txt_emb)
        neg_cosine = cls.compute_image_text_similarity(image_emb, neg_txt_emb)
        return pos_cosine - neg_cosine
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class RawImageModel:
    """Computes image embeddings from raw image files with the
    ``google/cxr-foundation`` TensorFlow models (ELIXR-C followed by QFormer).
    """

    def __init__(self):
        """Download (if needed) and load both saved models."""
        self.elixrc_model = None
        self.qformer_model = None
        self._load_model()

    def _load_model(self):
        """Loads the TensorFlow model from Hugging Face.

        Raises:
            ImportError: if TensorFlow or tensorflow-text is not installed.
        """
        try:
            import tensorflow as tf
            import tensorflow_text  # noqa: F401 -- imported only to register its ops
        except ImportError as err:
            # Chain the original error so the missing package stays visible
            raise ImportError(
                "TensorFlow or tensorflow-text is not installed. Use precomputed mode or install them."
            ) from err

        logger.info("Checking for GPU acceleration...")
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            logger.info(f"Running on GPU: {gpus}")
        else:
            logger.info("Running on CPU. Expect slower inference.")

        logger.info("Downloading model weights from Hugging Face...")
        model_path = snapshot_download(
            repo_id="google/cxr-foundation",
            allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*']
        )

        logger.info("Loading ELIXR-C (Image Encoder)...")
        self.elixrc_model = tf.saved_model.load(os.path.join(model_path, 'elixr-c-v2-pooled'))

        logger.info("Loading QFormer (Adapter)...")
        self.qformer_model = tf.saved_model.load(os.path.join(model_path, 'pax-elixr-b-text'))

    def compute_embeddings(self, image_path):
        """Generates embeddings for a raw image file.

        The file's bytes are wrapped in a ``tf.train.Example`` tagged as PNG,
        passed through ELIXR-C, and the resulting feature map is fed to the
        QFormer with all-zero text inputs.

        Args:
            image_path: path of the image file. NOTE(review): the bytes are
                always labelled as PNG; whether other formats decode
                correctly depends on the model's input pipeline -- confirm.

        Returns:
            The ``all_contrastive_img_emb`` output as a NumPy array.

        Raises:
            Exception: any read or inference error is logged and re-raised.
        """
        import tensorflow as tf

        try:
            # The file is read as-is; no re-encoding is done here.
            with open(image_path, 'rb') as f:
                image_bytes = f.read()

            # Create TF Example
            example = tf.train.Example()
            features = example.features.feature
            features['image/encoded'].bytes_list.value.append(image_bytes)
            features['image/format'].bytes_list.value.append(b'png')
            serialized_example = example.SerializeToString()

            # Step 1: ELIXR-C
            elixrc_infer = self.elixrc_model.signatures['serving_default']
            elixrc_output = elixrc_infer(input_example=tf.constant([serialized_example]))
            elixrc_embedding = elixrc_output['feature_maps_0'].numpy()  # Shape (1, 8, 8, 1376)

            # Step 2: QFormer
            # Initialize text inputs with zeros (as we only want image embeddings)
            qformer_input = {
                'image_feature': elixrc_embedding.tolist(),
                'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
                'paddings': np.zeros((1, 1, 128), dtype=np.float32).tolist(),
            }

            qformer_output = self.qformer_model.signatures['serving_default'](**qformer_input)
            elixrb_embeddings = qformer_output['all_contrastive_img_emb'].numpy()  # Shape (1, 32, 128)

            return elixrb_embeddings

        except Exception as e:
            logger.error(f"Error computing raw embeddings: {e}")
            raise
|
src/plot_kaggle_roc.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
from sklearn.metrics import roc_curve, auc
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# Configure logging at import time: INFO level, "<timestamp> - <level> - <message>" format
|
| 9 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
def plot_roc_curve(results_path, output_image_path):
    """
    Reads predictions CSV, calculates AUC, and plots ROC curve.

    Args:
        results_path: Path to a CSV file containing a 'true_label' column
            and a 'pneumothorax_score' column.
        output_image_path: Path the ROC curve image is saved to.

    Returns:
        None. Every failure (missing file, missing columns, no usable rows,
        single-class labels, unexpected exception) is logged and the function
        returns without raising.
    """
    if not os.path.exists(results_path):
        logger.error(f"Results file not found: {results_path}")
        return

    try:
        df = pd.read_csv(results_path)
        logger.info(f"Loaded {len(df)} predictions from {results_path}")

        # Fail with an explicit message instead of an opaque KeyError further down.
        required_columns = ('true_label', 'pneumothorax_score')
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            logger.error(f"Results file is missing required columns: {missing_columns}")
            return

        # Filter out errors (rows that have no score)
        df = df.dropna(subset=['pneumothorax_score'])
        if df.empty:
            logger.error("No valid predictions found.")
            return

        # Prepare True Labels (Binary)
        # Kaggle Labels: 'Pneumothorax' vs 'No Pneumothorax'
        y_true = (df['true_label'] == 'Pneumothorax').astype(int)
        y_scores = df['pneumothorax_score']

        # A ROC curve needs both positives and negatives; with one class the
        # rates are undefined and the resulting plot would be meaningless.
        if y_true.nunique() < 2:
            logger.error("Cannot compute ROC curve: true labels contain only one class.")
            return

        # Calculate ROC and AUC (the thresholds themselves are not needed)
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        roc_auc = auc(fpr, tpr)
        logger.info(f"Calculated AUC: {roc_auc:.4f}")

        # Plot
        fig = plt.figure(figsize=(8, 6))
        try:
            plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
            # Diagonal reference line from (0, 0) to (1, 1)
            plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('ROC Curve - Zero-Shot Pneumothorax Classification (Kaggle)')
            plt.legend(loc="lower right")
            plt.grid(True, alpha=0.3)

            plt.savefig(output_image_path)
            logger.info(f"ROC curve saved to {output_image_path}")
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

    except Exception as e:
        # Best-effort script boundary: log with traceback rather than propagate.
        logger.exception(f"Failed to plot ROC curve: {e}")
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":
    # Script entry point: read the Kaggle predictions CSV and write the ROC
    # image, both under the results/ directory (relative to the working dir).
    plot_roc_curve(
        "results/kaggle_predictions.csv",
        "results/kaggle_roc_curve.png",
    )
|