Alzheimer’s disease prediction using deep neural networks
This project was carried out as part of the TechLabs “Digital Shaper Program” in Aachen (Winter Term 2021/2022)
Introduction
Alzheimer’s disease (AD), characterised by the progressive impairment of cognitive and memory functions, is the most common type of dementia and typically appears in persons over 65 years old. To slow the progression of dementia, timely treatment is crucial, which requires the early diagnosis of AD. Recently, deep learning methods, which can take the intercorrelation between brain regions into account, have become an attractive and fundamental element of computer-assisted analysis and have been widely employed for the automated diagnosis and analysis of neuropsychiatric disorders.
This project focused on the classification and prediction of Alzheimer’s disease. We wanted to create a highly accurate model that predicts the stage of Alzheimer’s from MRI images to help doctors with the diagnosis and, if MRI images are not yet available, gives the doctor a first assessment of the patient’s dementia risk based on patient data.
We then explored two tasks:
1. To date, the analysis of neuro-imaging data, such as those obtained from magnetic resonance imaging, has primarily been performed by experts such as radiologists and physicians, thus requiring a high degree of specialisation. We explored different CNN models to automatically predict the stage of Alzheimer’s disease based on a patient’s MRI image. There are four stages: not demented, very mildly demented, mildly demented, and moderately demented.
2. Predict the risk of a person developing dementia based on possible correlations with factors such as sex, age, socioeconomic status, and years of education. We explored several machine learning models to classify this risk.
Method
To solve these tasks, we decided to try out several models for each prediction task. First, data preprocessing and data exploration were performed. The whole solution pipeline can be summarised in the following steps:
- Data preprocessing
- Data exploration
- Prediction of the stage of Alzheimer's disease based on MRI scans
- Prediction of the risk of developing dementia
1. Data preprocessing

Task 1: Predicting the stage of Alzheimer's disease from MRI scans is an image recognition task. For data preprocessing, the class labels were converted from text to numbers and one-hot encoded. There are four classes of images, corresponding to the four stages of dementia. We had around 6,400 MRI images, which were split into training and test sets with a test size of 0.2. Before training, image data augmentation was performed: the images were randomly rotated by up to 30°, zoomed, shifted in height and width, and finally flipped horizontally and vertically.

Task 2: The second task is a basic binary classification task, with the data available in a CSV file. Using the trained classifier, one can predict the probability of a person developing dementia based on several features. The class labels are demented vs. non-demented; the class "Converted" was dropped as it was highly underrepresented, so the task was treated as a binary classification problem. Binary features were converted to 0/1 values and the remaining features were min-max scaled to the range 0 to 1. Rows with missing values were dropped, as their number was insignificant compared to the size of the dataset. Around 350 data entries remained, which were split into training and test sets with a test size of 0.2.
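The following is a minimal sketch of the augmentation step described above, using Keras' ImageDataGenerator. The image size, zoom and shift ranges, and batch size are assumptions, not the exact project settings.

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Dummy stand-ins for the real MRI arrays (in practice loaded from the dataset;
# the 176x208 slice size and 3 channels are assumptions).
X_train = np.random.rand(8, 176, 208, 3)           # (samples, height, width, channels)
y_train = np.eye(4)[np.random.randint(0, 4, 8)]    # one-hot encoded stage labels

datagen = ImageDataGenerator(
    rotation_range=30,       # random rotations up to 30 degrees
    zoom_range=0.2,          # random zoom
    width_shift_range=0.1,   # random horizontal shifts
    height_shift_range=0.1,  # random vertical shifts
    horizontal_flip=True,    # random horizontal flips
    vertical_flip=True,      # random vertical flips
)
train_generator = datagen.flow(X_train, y_train, batch_size=4)
```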
2. Data exploration

Task 1: To explore the data, an average image and a standard deviation image were computed. The MRI scans were taken at various brain slices; nevertheless, differences are noticeable.
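A short sketch of how such average and standard-deviation images can be computed with NumPy; `images` is a dummy stand-in for the loaded MRI arrays.

```python
import numpy as np

images = np.random.rand(100, 176, 208)   # dummy stand-in: (n_images, height, width)
mean_image = images.mean(axis=0)         # pixel-wise average image
std_image = images.std(axis=0)           # pixel-wise standard-deviation image
```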
The representation of the different classes was also explored. It was found that the Moderate Demented and Mild Demented classes are highly underrepresented:
- Non Demented: 3200 pictures
- Very Mild Demented: 2240 pictures
- Moderate Demented: 64 pictures
- Mild Demented: 896 pictures.
Task 2:
For the second task, the correlation between the different features was explored.
A fairly high positive correlation of 0.56 can be noticed between gender (M/F) and the Atlas Scaling Factor (ASF). A very high negative correlation can be observed between the Estimated Total Intracranial Volume (eTIV) and the Atlas Scaling Factor (ASF); weaker correlations exist between the Clinical Dementia Rating (CDR) and the Mini Mental State Examination (MMSE), and between the Normalised Whole Brain Volume (nWBV) and Age.
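A sketch of this correlation analysis, assuming pandas/seaborn and a CSV file with the columns listed in the Task 2 feature table; the file name is a hypothetical placeholder.

```python
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.read_csv("alzheimer.csv")              # hypothetical file name
corr = df.select_dtypes("number").corr()       # pairwise correlations of numeric columns
sns.heatmap(corr, annot=True, cmap="coolwarm")
plt.title("Feature correlations")
plt.show()
```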
3. Prediction of the stage of Alzheimer's disease based on MRI scans
To predict the stage of Alzheimer's disease, several models were utilised. As stated, this is a multi-class image recognition task, for which several convolutional deep neural network architectures pre-trained on the ImageNet dataset were used. The task was approached as transfer learning, where only the last few output layers of each network are trained on the new data. The following architectures were used:
- AlexNet
- Inception
- Xception
- MobileNetV2
- ResNet50
- VGG16
Each model returns the probability of an image belonging to each class and assigns the image to the class with the highest probability. Every architecture is briefly described in the Results section.
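Below is a minimal transfer-learning sketch in Keras, following the approach described above. The choice of MobileNetV2 here, the input size, and the layer widths are illustrative assumptions rather than the exact project configuration.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Pre-trained convolutional base with the classification head removed;
# the 176x208 RGB input shape is an assumption about the MRI slices.
base = tf.keras.applications.MobileNetV2(
    weights="imagenet", include_top=False, input_shape=(176, 208, 3))
base.trainable = False                      # keep the pre-trained weights frozen

model = models.Sequential([
    base,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation="relu"),
    layers.Dense(4, activation="softmax"),  # four dementia stages
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss="categorical_crossentropy", metrics=["accuracy"])
```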
4. Prediction of the risk of developing dementia

The second task, focusing on the prediction of the risk of developing dementia, is a binary classification problem. For that, the following classifiers were used:
- Fully connected neural network
- Random Forest
- XGBoost
- KNN
Based on the given features, the models predict whether a person is demented or not.
Project Results
Task 1:
The data consist of MRI images with four classes in both the training and the test set: 1. Mild Demented, 2. Moderate Demented, 3. Non Demented, 4. Very Mild Demented.
The different models used for training and prediction, and their accuracies, are described below:
- AlexNet:
We use the ReLU activation function for each layer and the softmax function for the output layer.
The following shows a visualisation of the model’s metrics:
The following is the confusion matrix:
The accuracy of the model stops improving (by more than 0.001) after 9 epochs, and the final accuracy on the test set is around 50%, with the following classification report:
The Kaggle notebook can be viewed here.
- Xception:
Xception uses a modified depth-wise separable convolution and is reported to outperform earlier Inception models on image classification tasks. The modified operation is a point-wise convolution followed by a depth-wise convolution, instead of the other way around.
The accuracy of the model stops improving after 19 epochs, and the final accuracy on the test set is only around 34%, with the following classification report:
The Kaggle notebook can be viewed here.
- InceptionV3:
InceptionV3 is a convolutional neural network model pre-trained on images from the ImageNet database.
The ReLU activation function is used for the dense layers, except for the output layer, where the softmax activation function is used. The optimiser implements the Adam algorithm with a learning rate of 0.001. Using early stopping, the training of the model was stopped at epoch 20, when the validation accuracy did not improve beyond 0.81309 (a short sketch of this early-stopping setup is included at the end of this subsection). The following shows a visualisation of the model's metrics:
On the test dataset, the model achieves an F1 score of 45.07% and an accuracy of 50.70%. The following shows the confusion matrix:
The Kaggle notebook can be viewed here.
The web application was built using streamlit and can be accessed through this link.
You can find the source code for the streamlit web application here.
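A sketch of the early-stopping setup described above: training stops once the validation accuracy stops improving. The patience value and epoch budget are assumptions; `model`, `datagen`, `X_train`, and `y_train` refer to the earlier sketches.

```python
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping

X_val = np.random.rand(4, 176, 208, 3)           # dummy validation images
y_val = np.eye(4)[np.random.randint(0, 4, 4)]    # dummy one-hot labels

early_stop = EarlyStopping(monitor="val_accuracy", patience=5,
                           restore_best_weights=True)
history = model.fit(datagen.flow(X_train, y_train, batch_size=4),
                    validation_data=(X_val, y_val),
                    epochs=50, callbacks=[early_stop])
```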
- MobileNetV2
MobileNetV2 is a pre-trained convolutional neural network model that aims to perform well on mobile devices. It is also trained using images from the ImageNet database.
As with the InceptionV3 model, the ReLU activation function is used for the dense layers, and softmax is chosen as the activation function for the output layer. The optimiser also implements the Adam algorithm with a learning rate of 0.001. Training was stopped after 13 epochs because early stopping was implemented: the validation accuracy did not improve beyond 0.82441. The following shows a visualisation of the model's metrics:
The model has an F1 score of 35.84% and an accuracy of 39.38% when the testing dataset is used. The following image depicts the corresponding confusion matrix:
The Kaggle notebook can be viewed here.
The web application was built using streamlit. Click here to view the web application.
The source code for the web application can be found here.
- ResNet50:
ResNet stands for Residual Network; it is a convolutional neural network model pre-trained on images from the ImageNet database. ResNet has many variants that build on the same concept but have different numbers of layers. ResNet50 denotes the variant with 50 neural network layers.
The ReLU activation function is used for the dense layers, except for the output layer, where the softmax activation function is used. The optimiser implements the Adam algorithm with a learning rate of 0.001. Using early stopping, the training of the model was stopped at epoch 24, when the validation accuracy did not improve beyond 0.82754. The following shows a visualisation of the model's metrics:
The model has an F1 score of 39.45% and an accuracy of 52.73% when the testing dataset is used. The following image depicts the corresponding confusion matrix:
The Kaggle notebook can be viewed here.
The web application was built using streamlit. Click here to view the web application.
The source code for the web application can be found here.
- VGG16:
VGG16 stands for "Visual Geometry Group", and 16 indicates that the architecture has 16 layers. The architecture was introduced in 2014 and trained on the ImageNet dataset. It improved over the AlexNet architecture in the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) 2014 by replacing large kernel-sized filters with multiple 3x3 kernel-sized filters. In this task, transfer learning was used: the output layers of the original model are not included, so new output layers for prediction are added and trained. Average pooling was used.
The model was compiled with the Adam optimiser with a learning rate of 0.001 and categorical cross-entropy as the loss. Using early stopping, the training of the model was stopped at epoch 32, when the validation accuracy did not improve beyond 0.9005. The following shows a visualisation of the model's metrics:
The model has an F1 score of 42.65% and an accuracy of 53.91% on the test dataset. It produces the following confusion matrix:
The Kaggle notebook can be viewed here.
The web application was built using streamlit. Click here to view the web application.
The source code for the web application can be found here.
Task 2:
In this task, the data consist of the following 9 features:
- Group: class label
- Age: age
- EDUC: years of education
- SES: socioeconomic status
- MMSE: Mini Mental State Examination
- CDR: Clinical Dementia Rating
- eTIV: estimated total intracranial volume
- nWBV: normalised whole brain volume
- ASF: Atlas Scaling Factor
The different models used for training and prediction, and their accuracies, are described below:
- Random Forest:
From the total of 9 features, we select the most significant ones:
We then find that the CDR and MMSE features are the most important, and we can use them to classify whether a person is prone to Alzheimer's risk or not. We obtain the following accuracy score and confusion matrix; a short code sketch of this step follows them.
Accuracy score:
Confusion matrix:
The Kaggle notebook can be viewed [here](https://www.kaggle.com/code/khaoulabc/alzheimer-features-random-forest).
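A minimal sketch of the feature-importance step described above, assuming scikit-learn, a hypothetical alzheimer.csv file, and the column names from the Task 2 feature table; it is illustrative rather than the exact project code.

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

df = pd.read_csv("alzheimer.csv").dropna()               # hypothetical file name
features = ["Age", "EDUC", "SES", "MMSE", "CDR", "eTIV", "nWBV", "ASF"]
X = df[features]
y = (df["Group"] == "Demented").astype(int)              # 1 = demented, 0 = non-demented

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)

# Rank features by importance; CDR and MMSE are expected to come out on top.
importances = pd.Series(rf.feature_importances_, index=features).sort_values(ascending=False)
print(importances)
print("Test accuracy:", rf.score(X_test, y_test))
```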
The web application was built using streamlit and can be accessed through this link.
You can find the source code for the streamlit web application here.
- Fully Connected Neural Network:
A four-layer fully connected network is trained, using the ReLU activation function for the hidden layers and softmax activation for the output layer for binary classification. All features are used as input.
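A minimal sketch of such a fully connected classifier in Keras; the layer widths and the input dimension are assumptions, and only the four-layer ReLU/softmax structure follows the description above.

```python
from tensorflow.keras import layers, models

mlp = models.Sequential([
    layers.Input(shape=(8,)),                # number of input features (assumed)
    layers.Dense(64, activation="relu"),
    layers.Dense(32, activation="relu"),
    layers.Dense(16, activation="relu"),
    layers.Dense(2, activation="softmax"),   # demented vs. non-demented
])
mlp.compile(optimizer="adam", loss="categorical_crossentropy",
            metrics=["accuracy"])
mlp.summary()
```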
A perfect F1 score and accuracy were achieved.
Confusion matrix:
The Kaggle notebook can be viewed here.
The web application was built using streamlit. Click here to view the web application.
The source code for the web application can be found here.
- KNN:
We first use all features and try to find the best parameter K for the model.
The cross-validation accuracy is 0.92, but the best K would be 1. We assume that CDR and MMSE are too highly correlated with the dependent variable and want to see what the parameter would look like without these variables. Without them, the best K is 7 neighbours, with a cross-validation accuracy score of 0.69.
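A sketch of the search for the best K described above; the K range and cross-validation setup are assumptions, and X and y are the preprocessed features and labels as prepared in the Random Forest sketch above.

```python
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

scores = {}
for k in range(1, 21):
    knn = KNeighborsClassifier(n_neighbors=k)
    scores[k] = cross_val_score(knn, X, y, cv=5).mean()   # mean CV accuracy

best_k = max(scores, key=scores.get)
print("Best K:", best_k, "with CV accuracy:", round(scores[best_k], 2))
```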
The Kaggle notebook can be viewed here.
The web application was built using streamlit and can be accessed through this link.
You can find the source code for the streamlit web application here.
- XGBoost
XGBoost is a software library that implements gradient-boosted decision trees, optimised for speed and performance. It is an approach where new models are added to predict the residual errors of prior models, and the results are combined to make the final prediction. It uses gradient descent to minimise the loss when adding new models. For this task, the XGBoost binary classifier with logistic loss was used.
Randomised search was utilised to find the best model; the F1 score of the best model was 96.9%.
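A sketch of the randomised hyper-parameter search described above; the parameter grid, number of iterations, and scoring metric are assumptions, and X_train and y_train are the preprocessed training data from the earlier sketch.

```python
from xgboost import XGBClassifier
from sklearn.model_selection import RandomizedSearchCV

param_dist = {
    "n_estimators": [50, 100, 200, 400],
    "max_depth": [2, 3, 4, 6],
    "learning_rate": [0.01, 0.05, 0.1, 0.3],
    "subsample": [0.6, 0.8, 1.0],
}
search = RandomizedSearchCV(
    XGBClassifier(objective="binary:logistic"),   # logistic loss, as described above
    param_distributions=param_dist, n_iter=20, scoring="f1", cv=5, random_state=42)
search.fit(X_train, y_train)
print(search.best_params_, search.best_score_)
```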
Feature importance was also examined; the most important feature is the Clinical Dementia Rating (CDR).
The Kaggle notebook can be viewed here.
The web application was built using streamlit and can be accessed through this link.
You can find the source code for the streamlit web application here.
Conclusion
Possibilities for improvement
For the first task, all models have difficulty predicting the minority classes (MildDemented and ModerateDemented). This is due to the dataset being imbalanced. A possible solution to this problem would be to implement oversampling or undersampling. Another possible way to improve the models with respect to this problem is to combine the data in the classes MildDemented, ModerateDemented and VeryMildDemented. This would transform the problem into a binary classification problem with the labels Demented and NonDemented.
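One possible way to address the imbalance, not implemented in this project, is random oversampling of the minority classes with the imbalanced-learn package; X_train and y_train here refer to the image arrays and one-hot labels from the earlier sketches.

```python
from imblearn.over_sampling import RandomOverSampler

y_labels = y_train.argmax(axis=1)                   # one-hot -> integer stage labels
flat = X_train.reshape(len(X_train), -1)            # flatten images to feature vectors
X_res, y_res = RandomOverSampler(random_state=42).fit_resample(flat, y_labels)
X_res = X_res.reshape(-1, *X_train.shape[1:])       # restore the original image shape
```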
There is also room for improvement in the deployment of our models. Instead of creating a web application for each model, we could deploy all the models in a single web application, which would allow the user to choose their preferred model from a drop-down menu.
Final Thoughts
Through this project, we have learned:
- Data preprocessing
- Image augmentation
- Building and fine-tuning deep learning models using different convolutional neural networks
- Preventing overfitting or underfitting by implementing early stopping
- Measuring the performance of our models
- Deploying our models by creating web applications
- Teamwork and communication, creating a pitch, and much more!
Therefore, we would like to thank TechLabs Aachen for this opportunity to learn and our mentor for guiding us through this project.
Team Members
Khaoula Ben Chaabane
Shree Harsha Shivashankar Bhat — LinkedIn; GitHub; Kaggle
Natálie Brožová — LinkedIn
Kristie Lim
Yihan Xu
Mentor
Hachem Sfar LinkedIn, GitHub, Kaggle
TechLabs Aachen e.V. reserves the right not to be responsible for the topicality, correctness, completeness or quality of the information provided. All references are made to the best of the authors’ knowledge and belief. If, contrary to expectation, a violation of copyright law should occur, please contact journey.ac@techlabs.org so that the corresponding item can be removed.