Automatic Grading of Diabetic Retinopathy through Deep Learning (MAC403)
Apaar Sadhwani, Leo Tam, and Jason Su
Advisors: Robert Chang, Jeff Ullman, Andreas Paepcke
November 30, 2016
Problem, Data and Motivation
- Widespread disease: affects ~100M people, many in the developed world, ~45% of diabetics. Enable early diagnosis and care.
- Motivation: make the grading process faster, assist the ophthalmologist, enable self-help.
- Task: given a fundus image, rate the severity of Diabetic Retinopathy
  - 5 classes: 0 (Normal), 1, 2, 3, 4 (Severe)
  - Hard classification (though it may also be posed as an ordinal problem)
  - Metric: quadratic weighted kappa (QWK), with a (pred - real)^2 penalty (a computation sketch follows this slide)
- Data from Kaggle (California Healthcare Foundation, EyePACS)
  - ~35,000 training images, ~54,000 test images
  - High resolution: variable, more than 2560 x 1920
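The quadratic weighted kappa can be computed directly from the confusion matrix of predicted vs. true grades; below is a minimal numpy sketch for illustration (not the official Kaggle scoring code).

```python
import numpy as np

def quadratic_weighted_kappa(y_true, y_pred, n_classes=5):
    """Quadratic weighted kappa for integer grades 0..n_classes-1."""
    O = np.zeros((n_classes, n_classes))                    # observed confusion matrix
    for t, p in zip(y_true, y_pred):
        O[t, p] += 1
    idx = np.arange(n_classes)
    W = (idx[:, None] - idx[None, :]) ** 2 / (n_classes - 1) ** 2   # (pred - real)^2 penalty
    E = np.outer(O.sum(axis=1), O.sum(axis=0)) / O.sum()    # expected matrix under chance agreement
    return 1.0 - (W * O).sum() / (W * E).sum()

# example: the single error (true 4, predicted 3) is penalized lightly;
# predicting 0 there instead would drag kappa down much further
print(quadratic_weighted_kappa([0, 2, 4, 1, 3], [0, 2, 3, 1, 3]))
```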
Example images Class 0 (normal) Class 4 (severe)
Challenges
- High resolution images, yet the discriminative features are small
  - Atypical in vision; causes GPU batch-size issues (224 x 224 allows a batch of 128, 2K x 2K only 2)
- Grading criteria not fully specified: the EyePACS guidelines tell us which features the algorithm should learn, the rest must be learned from data
- Incorrect labeling (mentioned in the problem statement, confirmed with doctors)
- Artifacts in ~40% of the images
- Optimizing approach to QWK
  - The penalty grows quadratically with the distance between predicted and true class (predicting 1 for a true 3 costs more than predicting 2)
  - Hard classification is non-differentiable, so backprop on QWK directly is difficult
  - A squared-error approximation is differentiable and close to the metric
- Severe class imbalance: class 0 dominates
  - Naive training degenerates into a 3-class problem, or predicts all zeros
  - Learn all classes separately (1 vs. all)? Sample in a balanced way while training (see the sketch after this slide), but what to do at test time?
- Too few training examples: big learning models need more data; harness the test set?
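Balanced sampling during training is one way to keep class 0 from dominating each batch; below is a minimal numpy sketch (a hypothetical helper, not the exact sampling scheme used in the project). At test time the true, imbalanced distribution still applies, which is why the later experiments move from even sampling back toward true sampling.

```python
import numpy as np

def balanced_batch_indices(labels, batch_size, n_classes=5, rng=np.random):
    """Draw a batch with roughly equal counts of each grade,
    oversampling the rare grades (3, 4) when necessary."""
    per_class = batch_size // n_classes
    idx_by_class = [np.where(labels == c)[0] for c in range(n_classes)]
    picks = [rng.choice(idx, size=per_class, replace=len(idx) < per_class)
             for idx in idx_by_class]
    return np.concatenate(picks)

# usage: labels holds the training grades (0..4); class 0 dominates
labels = np.array([0] * 700 + [1] * 70 + [2] * 150 + [3] * 25 + [4] * 20)
batch = balanced_batch_indices(labels, batch_size=125)
print(np.bincount(labels[batch], minlength=5))   # roughly 25 samples per class
```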
Conventional Approaches
Literature survey:
- Hand-designed features to pick out each component
- Clean images, small datasets
- Optic disk and exudate segmentation: fail due to artifacts
- SVM: poor performance
Our Approach
1. Registration and pre-processing
2. Convolutional Neural Nets (CNNs)
3. Hybrid architecture
Step 1: Pre-processing
- Registration
  - Hough circles to locate the fundus; remove the portion outside the circle
  - Downsize to a common size (224 x 224, 1K x 1K)
- Color correction
- Normalization (mean, variance)
(An illustrative OpenCV sketch of these steps follows this slide.)
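A rough illustration of the pre-processing pipeline with OpenCV is sketched below; the blur size and Hough thresholds are placeholder values, not the project's actual settings.

```python
import cv2
import numpy as np

def preprocess_fundus(path, out_size=224):
    img = cv2.imread(path)                                    # BGR fundus photo
    gray = cv2.medianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 5)

    # locate the circular fundus region (placeholder Hough parameters)
    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, dp=2,
                               minDist=gray.shape[0], param1=100, param2=50,
                               minRadius=gray.shape[0] // 4)
    x, y, r = np.round(circles[0, 0]).astype(int)

    # remove everything outside the detected circle, crop its bounding box
    mask = np.zeros(gray.shape, dtype=np.uint8)
    cv2.circle(mask, (x, y), r, 255, thickness=-1)
    img = cv2.bitwise_and(img, img, mask=mask)
    img = img[max(y - r, 0):y + r, max(x - r, 0):x + r]

    # downsize to a common size, then normalize to zero mean / unit variance
    img = cv2.resize(img, (out_size, out_size)).astype(np.float32)
    return (img - img.mean()) / (img.std() + 1e-8)
```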
Step 2: CNNs
- Network in Network (NiN) architecture, input image -> class probabilities
  - Stack: 3 conv layers (depth 96) -> MaxPool (stride 2) -> 3 conv layers (depth 256) -> MaxPool (stride 2) -> 3 conv layers (depth 384) -> MaxPool (stride 2) -> final 3 conv layers -> AvgPool
  - ImageNet version: final block depth 1024, 7.5M parameters; adapted for our 5 classes with final depths 384, 64, 5, roughly 2.2M parameters
  - No FC layers; spatial average pooling instead (a PyTorch-style sketch follows this slide)
- Transfer learning (ImageNet)
  - Variable learning rates: low for the "ImageNet" layers, with a schedule
- Combat lack of data and over-fitting
  - Dropout, early stopping
  - Data augmentation (flips, rotation)
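For illustration, here is a PyTorch-style sketch of a NiN block stack with global average pooling in place of FC layers. The project itself used Lua Torch; block depths follow the slide, while kernel sizes and the exact layout are assumptions.

```python
import torch
import torch.nn as nn

def nin_block(in_ch, depths, kernel=3):
    """Three conv layers: one spatial kernel followed by two 1x1 'micro-MLP' convs."""
    d1, d2, d3 = depths
    return nn.Sequential(
        nn.Conv2d(in_ch, d1, kernel, padding=kernel // 2), nn.ReLU(inplace=True),
        nn.Conv2d(d1, d2, 1), nn.ReLU(inplace=True),
        nn.Conv2d(d2, d3, 1), nn.ReLU(inplace=True))

class RetinopathyNiN(nn.Module):
    def __init__(self, n_classes=5):
        super().__init__()
        self.features = nn.Sequential(
            nin_block(3,   (96, 96, 96)),    nn.MaxPool2d(3, stride=2),
            nin_block(96,  (256, 256, 256)), nn.MaxPool2d(3, stride=2),
            nin_block(256, (384, 384, 384)), nn.MaxPool2d(3, stride=2),
            nn.Dropout(0.5),
            nin_block(384, (384, 64, n_classes)))   # final block depths 384, 64, 5
        self.pool = nn.AdaptiveAvgPool2d(1)          # spatial average pooling, no FC layers

    def forward(self, x):
        x = self.pool(self.features(x))              # (batch, n_classes, 1, 1)
        return x.flatten(1)                          # class scores; softmax/NLL applied in the loss

model = RetinopathyNiN()
scores = model(torch.randn(2, 3, 224, 224))          # a batch of two 224 x 224 images
```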
Step 2: CNN Experiments
- What image size to use? Strategize using 224 x 224, then extend to 1024 x 1024
- What loss function? (recall the QWK discussion under Challenges)
  - Mean squared error (MSE): conveys grade cardinality and closely approximates what we really want to optimize
  - Negative log likelihood (NLL)
  - Linear combination of the two, with annealing (a sketch follows this slide)
- Class imbalance: even sampling while training, moving to true sampling
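One plausible reading of "linear combination with annealing" is a blended loss whose weight shifts from NLL towards MSE over training; the PyTorch sketch below illustrates the idea under that assumption (the exact schedule, and whether MSE was applied to an expected grade or to a scalar regression output, may differ from the project).

```python
import torch
import torch.nn.functional as F

def combined_loss(scores, targets, epoch, total_epochs, n_classes=5):
    """Linear combination of NLL and MSE, annealed from NLL-heavy to MSE-heavy."""
    alpha = min(epoch / total_epochs, 1.0)            # 0 -> pure NLL, 1 -> pure MSE
    nll = F.cross_entropy(scores, targets)            # NLL on the class scores
    probs = F.softmax(scores, dim=1)                  # expected grade under the prediction
    grades = torch.arange(n_classes, dtype=probs.dtype, device=probs.device)
    expected = (probs * grades).sum(dim=1)
    mse = F.mse_loss(expected, targets.float())       # conveys cardinality, approximates QWK
    return (1 - alpha) * nll + alpha * mse

# usage in a training loop: scores (batch, 5), targets (batch,) integer grades
scores = torch.randn(8, 5, requires_grad=True)
targets = torch.randint(0, 5, (8,))
loss = combined_loss(scores, targets, epoch=3, total_epochs=30)
loss.backward()
```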
Step 2: CNN Experiments (image size 224 x 224)
With the pre-trained ImageNet layers frozen (no learning in those layers, top layers only):
- MSE loss: fails to learn
- NLL loss: Kappa < 0.1 at first, reaching Kappa = 0.29
With the pre-trained layers fine-tuned at 0.01x the step size:
- NLL: Kappa improves from 0.29 to 0.42, then 0.51
- MSE loss: Kappa = 0.56
Step 2: CNN Results
Computing Setup
- Amazon EC2 GPU nodes, VPC, Amazon EBS-optimized instances
  - Single-GPU nodes for 224 x 224 (g2.2xlarge)
  - Multi-GPU nodes for 1K x 1K (g2.8xlarge)
- Storage: EBS, Amazon S3
- Python for processing; Torch library (Lua) for training
Computing Setup (architecture)
- Data lives on EBS (gp2) volumes (Data 1, Data 2, ...), with a snapshot kept in S3 to spin up copies
- A central node inside a VPC on EC2 (EBS-optimized) coordinates the model experiments
- Each master drives a set of GPU model-experiment nodes on EC2 (Model 1 ... Model 10), reading data at ~200 MB/s
- The setup scales out to multiple masters (Master 1 with Models 1-10, Master 2 with Models 11-20)
Computing Setup (GPU instances)
- g2.2xlarge: 1-GPU node on EC2, 4 GB GPU memory, 1,536 CUDA cores on the single GPU
  - Batch size: 128 images of 224 x 224
  - Batch size: only 8 images of 1024 x 1024
- g2.8xlarge: 4-GPU node on EC2, 16 GB GPU memory
  - Data parallelism across the GPUs (see the sketch after this slide)
  - Batch size: ~28 images of 1024 x 1024
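A minimal sketch of the data-parallelism idea on the 4-GPU node, written with PyTorch's DataParallel for illustration (the project used Lua Torch's equivalent): each batch of large images is split across the GPUs and gradients are gathered on the default device.

```python
import torch
import torch.nn as nn

# a stand-in conv model; in practice this would be the NiN from the earlier sketch
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 5)).cuda()

# split each forward/backward pass across the 4 GPUs of a g2.8xlarge (devices 0..3)
parallel_model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])

images = torch.randn(28, 3, 1024, 1024).cuda()    # ~28 images of 1024 x 1024
scores = parallel_model(images)                   # each GPU processes a slice of ~7 images
```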
Step 3: Hybrid Architecture
- Two branches: a Main Network (on the 1024 image) and a Lesion Detector applied over 64 tiles of 256 x 256 (covering the 2048 image)
- The branches are fused to produce the class probabilities
Lesion Detector
1. Web viewer and annotation tool
2. Lesion annotation
3. Extract image patches
4. Train lesion classifier
Viewer and Lesion Annotation
Lesion Annotation
Extracted Image Patches
Train Lesion Detector
- Only hemorrhages so far
- Positives: 1,866 extracted patches from 216 images/subjects
- Negatives: ~25k class-0 images
- Pre-processing/augmentation: crop a random 256 x 256 image from the input, plus flips (a sketch follows below)
- Pre-trained Network in Network architecture
- Accuracy: 99% on negatives, 76% on positives
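The random 256 x 256 crop plus flips for augmenting the lesion patches could look roughly like the numpy sketch below (illustrative only; the original may, for example, use only horizontal flips).

```python
import numpy as np

def augment_patch(img, size=256, rng=np.random):
    """Random 256x256 crop from an input patch, plus random horizontal/vertical flips."""
    h, w, _ = img.shape                       # assumes the input patch is at least 256 x 256
    top = rng.randint(0, h - size + 1)
    left = rng.randint(0, w - size + 1)
    crop = img[top:top + size, left:left + size]
    if rng.rand() < 0.5:
        crop = crop[:, ::-1]                  # horizontal flip
    if rng.rand() < 0.5:
        crop = crop[::-1, :]                  # vertical flip
    return np.ascontiguousarray(crop)

# usage: an annotated lesion patch slightly larger than 256 x 256
patch = np.zeros((288, 288, 3), dtype=np.uint8)
crop = augment_patch(patch)                   # -> (256, 256, 3)
```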
Hybrid Architecture
- Lesion Detector: applied across the 64 tiles of 256 x 256, producing a 2 x 31 x 31 lesion map
- The lesion map is passed through 2 conv layers and fused with the Main Network
- The fused output gives the class probabilities (a fusion sketch follows this slide)
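A rough sketch of how the two branches could be combined: the large image is cut into 256 x 256 tiles, the lesion detector scores each tile, and the resulting 2-channel map goes through 2 conv layers before fusion with the main network's output. Shapes and layer sizes here are assumptions; the non-overlapping tiling below yields an 8 x 8 map rather than the 31 x 31 map on the slide, which presumably comes from overlapping strides.

```python
import torch
import torch.nn as nn

class HybridHead(nn.Module):
    def __init__(self, main_net, lesion_net, n_classes=5):
        super().__init__()
        self.main_net = main_net            # main CNN, e.g. on the 1024 x 1024 image
        self.lesion_net = lesion_net        # patch-level lesion classifier with 2 outputs
        self.lesion_convs = nn.Sequential(  # 2 conv layers over the lesion map
            nn.Conv2d(2, 16, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1))
        self.fuse = nn.Linear(n_classes + 16, n_classes)

    def forward(self, img_1k, img_2k):
        main_scores = self.main_net(img_1k)                         # (b, 5)
        # cut the 2048 x 2048 image into an 8 x 8 grid of 256 x 256 tiles
        tiles = img_2k.unfold(2, 256, 256).unfold(3, 256, 256)      # (b, 3, 8, 8, 256, 256)
        b = tiles.shape[0]
        tiles = tiles.permute(0, 2, 3, 1, 4, 5).reshape(-1, 3, 256, 256)
        lesion_scores = self.lesion_net(tiles)                      # (b * 64, 2)
        lesion_map = lesion_scores.reshape(b, 8, 8, 2).permute(0, 3, 1, 2)   # (b, 2, 8, 8)
        lesion_feat = self.lesion_convs(lesion_map).flatten(1)      # (b, 16)
        return self.fuse(torch.cat([main_scores, lesion_feat], dim=1))      # class scores
```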
Training the Hybrid Architecture
- The Lesion Detector (over the 64 tiles of 256 x 256) and the Main Network feed the fuse layers, which output the class probabilities
- The loss on the class probabilities is backpropagated from the output through the fuse layers and into the network
Other Insights
- Supervised-unsupervised learning
- Distillation
- Hard-negative mining
- Other lesion detectors
- Attention CNNs
- Both eyes
- Ensembles
Clinical Importance
- 3-class problem
- True "4" problem
- Combining imaging modalities (OCT)
- Longitudinal analysis
Many thanks to…
- Amazon Web Services (AWS Educate, AWS Cloud Credits for Research)
- Robert Chang
- Jeff Ullman
- Andreas Paepcke