Stroke Prediction

By: Yusuf Abdulla

Introduction

Stroke is the second leading cause of death worldwide, responsible for approximately 11% of all deaths according to the World Health Organization. According to the CDC, about 800,000 people in the United States have a stroke every year, and for around 600,000 of them it is their first stroke. In other words, most strokes each year are first-time strokes. Moreover, stroke-related costs in the United States were almost $50 billion between 2014 and 2015.

Our goal is to analyze patient data, look for patterns, visualize them, and ultimately try to predict whether a patient has had a stroke. In the real world, such a model could help flag patients at risk so doctors can start treating them early and prepare them in case a stroke does occur. The model is trained on patients who have already had a stroke, but patient information changes over time, so feeding the model updated information every once in a while lets us check whether a patient's profile has started to resemble that of patients who had a stroke. This could give us an early start on treatment, which could save many patients and a lot of money.

The model will not be 100% accurate, but a reasonably accurate model would still be very helpful for this goal. If a small portion of flagged patients would never have had a stroke anyway, that is acceptable: our goal is to find and treat as many patients as we can before they have a stroke. This tutorial goes through the full data science pipeline, and hopefully we all pick up valuable insights along the way!

Getting The Data

When looking for my dataset, I did not want it to be so complicated that it exceeds the scope of a tutorial but I also did not want it to be too simple. I ended up finding a nice dataset on Kaggle which had information about patients as well as whether they had a stroke or not. I downloaded the csv file found here: https://www.kaggle.com/fedesoriano/stroke-prediction-dataset. Fortunately for us the pandas library has a function to read in a csv file and create a DataFrame.

In [1]:
import pandas as pd
import numpy as np

data = pd.read_csv('healthcare-dataset-stroke-data.csv')
data.head()
Out[1]:
id gender age hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
0 9046 Male 67.0 0 1 Yes Private Urban 228.69 36.6 formerly smoked 1
1 51676 Female 61.0 0 0 Yes Self-employed Rural 202.21 NaN never smoked 1
2 31112 Male 80.0 0 1 Yes Private Rural 105.92 32.5 never smoked 1
3 60182 Female 49.0 0 0 Yes Private Urban 171.23 34.4 smokes 1
4 1665 Female 79.0 1 0 Yes Self-employed Rural 174.12 24.0 never smoked 1

The dataset has the following features:

  • 1) id: unique identifier
  • 2) gender: "Male", "Female" or "Other"
  • 3) age: age of the patient
  • 4) hypertension: 0 if the patient doesn't have hypertension, 1 if the patient has hypertension
  • 5) heart_disease: 0 if the patient doesn't have any heart diseases, 1 if the patient has a heart disease
  • 6) ever_married: "No" or "Yes"
  • 7) work_type: "children", "Govt_job", "Never_worked", "Private" or "Self-employed"
  • 8) Residence_type: "Rural" or "Urban"
  • 9) avg_glucose_level: average glucose level in blood
  • 10) bmi: body mass index
  • 11) smoking_status: "formerly smoked", "never smoked", "smokes" or "Unknown"*
  • 12) stroke: 1 if the patient had a stroke or 0 if not

*Note: "Unknown" in smoking_status means that the information is unavailable for this patient.

Tidying The Data

We start off by dropping the id column since it carries no information about the patient or about whether the patient will have a stroke.

In [2]:
data.drop(columns=['id'], inplace=True)
data.head()
Out[2]:
gender age hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
0 Male 67.0 0 1 Yes Private Urban 228.69 36.6 formerly smoked 1
1 Female 61.0 0 0 Yes Self-employed Rural 202.21 NaN never smoked 1
2 Male 80.0 0 1 Yes Private Rural 105.92 32.5 never smoked 1
3 Female 49.0 0 0 Yes Private Urban 171.23 34.4 smokes 1
4 Female 79.0 1 0 Yes Self-employed Rural 174.12 24.0 never smoked 1

Next, we get a general overview of the data as well as a statistical summary. These give us an idea of the shape of the dataset, the datatypes of the columns, and the descriptive statistics of the numeric columns. They also help us spot missing data, as we will see below.

In [3]:
# Overall description of data
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 11 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   gender             5110 non-null   object 
 1   age                5110 non-null   float64
 2   hypertension       5110 non-null   int64  
 3   heart_disease      5110 non-null   int64  
 4   ever_married       5110 non-null   object 
 5   work_type          5110 non-null   object 
 6   Residence_type     5110 non-null   object 
 7   avg_glucose_level  5110 non-null   float64
 8   bmi                4909 non-null   float64
 9   smoking_status     5110 non-null   object 
 10  stroke             5110 non-null   int64  
dtypes: float64(3), int64(3), object(5)
memory usage: 439.3+ KB
In [4]:
# Statistical summary of dataset
data.describe()
Out[4]:
age hypertension heart_disease avg_glucose_level bmi stroke
count 5110.000000 5110.000000 5110.000000 5110.000000 4909.000000 5110.000000
mean 43.226614 0.097456 0.054012 106.147677 28.893237 0.048728
std 22.612647 0.296607 0.226063 45.283560 7.854067 0.215320
min 0.080000 0.000000 0.000000 55.120000 10.300000 0.000000
25% 25.000000 0.000000 0.000000 77.245000 23.500000 0.000000
50% 45.000000 0.000000 0.000000 91.885000 28.100000 0.000000
75% 61.000000 0.000000 0.000000 114.090000 33.100000 0.000000
max 82.000000 1.000000 1.000000 271.740000 97.600000 1.000000

We observe that every column has 5110 non-null entries except bmi, which has only 4909. There are many ways of dealing with missing data. The first option is to drop the feature entirely, but we may not want to do that because every feature may be important for our predictions. The next option is to delete every row that has a null value in any column. This is done below and we end up with 4909 rows for our whole dataset. This is equivalent to throwing away all the information of any patient whose bmi was not recorded. Again, this may not be the best option here.

In [5]:
nonull_data = data.dropna()
len(nonull_data)
Out[5]:
4909
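
Before choosing between these options, it can help to count the missing values per column and to see what share of rows dropna() would discard. The following is a minimal sketch on a made-up table (the toy DataFrame and its values are assumptions for illustration, not rows from the dataset):

```python
import numpy as np
import pandas as pd

# Toy stand-in for the patient table (values are made up for illustration)
toy = pd.DataFrame({
    "age": [67.0, 61.0, 80.0, 49.0],
    "bmi": [36.6, np.nan, 32.5, np.nan],
    "stroke": [1, 1, 1, 0],
})

# Number of missing entries per column
missing_counts = toy.isna().sum()

# Share of rows that dropna() would discard
dropped_share = 1 - len(toy.dropna()) / len(toy)

print(missing_counts)
print(f"rows lost by dropna: {dropped_share:.0%}")
```

On the real dataset the same arithmetic gives 5110 - 4909 = 201 rows lost, or about 3.9% of the patients.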

A third option we have is to impute the data. This means filling in missing data with a good estimate such as the mean or median of the feature. While this does not give us the actual values, it lets us keep every feature and every instance while leaving the overall distribution of the imputed feature close to the observed one. This is a good option when only a small share of the values is missing, as is the case here (201 of 5110). Now we have to decide how to fill in the missing data, so we start off by creating a boxplot of the feature.

In [6]:
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 3))
sns.boxplot(x=data['bmi'])
plt.show()

We see that there are a lot of outliers, represented by the diamond-shaped points, some of which extend way beyond the upper whisker. The median is far less sensitive to outliers than the mean, so we choose it as the estimate we use to impute the missing data.

  • Note: We could have also looked at the descriptive statistics displayed above, seen that the 75th percentile was 33.1 while the max was 97.6, and come to a similar conclusion, but sometimes it is better to visualize the data so you can see what is going on.
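
The boxplot's notion of an outlier can be worked out from the same descriptive statistics. With the default whisker length (1.5 times the interquartile range), points beyond the fences computed below are drawn individually. This is a small worked calculation using the bmi quartiles from the describe() output; the variable names are only illustrative:

```python
# Quartiles and maximum of bmi taken from the describe() output above
q1, q3 = 23.5, 33.1
bmi_max = 97.6

iqr = q3 - q1                 # interquartile range
upper_fence = q3 + 1.5 * iqr  # points above this are drawn as outliers
lower_fence = q1 - 1.5 * iqr  # points below this are drawn as outliers

print(round(iqr, 1), round(lower_fence, 1), round(upper_fence, 1))
print(bmi_max > upper_fence)
```

The upper fence comes out at about 47.5, far below the maximum of 97.6, which agrees with the long tail of points on the right of the plot.
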
In [7]:
from sklearn.impute import SimpleImputer

imputer = SimpleImputer(strategy='median')
# SimpleImputer expects 2D input, so we pass a one-column DataFrame
data['bmi'] = imputer.fit_transform(data[['bmi']]).ravel()

len(data['bmi'])
Out[7]:
5110
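
For a single column, pandas alone can do the same kind of median imputation. The sketch below uses made-up bmi values with one extreme outlier to show why the median is the safer fill value; it is an alternative illustration, not the code used in this notebook:

```python
import numpy as np
import pandas as pd

# Made-up bmi values with one extreme outlier and two gaps
bmi = pd.Series([22.0, 24.0, np.nan, 26.0, 28.0, np.nan, 95.0])

median_bmi = bmi.median()  # NaN values are skipped by default
mean_bmi = bmi.mean()      # pulled upwards by the outlier

# Fill the gaps with the median
filled = bmi.fillna(median_bmi)

print(median_bmi, mean_bmi)
print(filled.tolist())
```

Here the single value of 95 drags the mean well above every typical value, while the median stays in the middle of the bulk of the data.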

Lastly, the smoking_status feature also has missing values, but they are already encoded for us as the category "Unknown". Since this is text data, unlike bmi, we cannot impute a median, so we keep "Unknown" as its own category. As seen below, a lot of values are missing, and we treat those instances as patients whose smoking status could be any of the other categories.

In [8]:
len(data[data['smoking_status'] == 'Unknown'])
Out[8]:
1544
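
A count is easier to judge as a share of the whole: 1544 of 5110 is about 30% of the patients. A short sketch of getting such shares directly with value_counts(normalize=True), on a made-up column:

```python
import pandas as pd

# Made-up smoking_status values for illustration
smoking = pd.Series(
    ["never smoked", "Unknown", "smokes", "Unknown", "formerly smoked"],
    name="smoking_status",
)

# Relative frequency of each category instead of raw counts
shares = smoking.value_counts(normalize=True)
unknown_share = shares["Unknown"]

print(f"{unknown_share:.1%}")
```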

Exploring The Data

The first step in exploring data is to get a look at the distributions (for numerical data) and counts (for categorical data) of the features and then to analyze them and come up with patterns and insights. A lot of the time when we write code and output plain numbers we don't really see or get to visualize what actually is happening even if we may understand what is going on. Plotting the data and playing around with various combinations gives us a natural feel for what's going on. After all, our most important sense is sight.

Numerical Data

Let's start with numerical data. The three numerical features we have are age, avg_glucose_level, and bmi. Although other features such as stroke contain numeric values (1/0), they are considered categorical because they are not continuous variables and simply encode a yes/no answer. We first get the numerical features and plot their distributions. We add a kde (kernel density estimate) line to give us a smooth representation of the distribution.

In [9]:
numericals = ['age', 'avg_glucose_level', 'bmi']
In [10]:
plt.figure(figsize=(20, 6))

for i, numerical in enumerate(numericals):
    ax = plt.subplot(1, 3, i+1)
    sns.histplot(data=data, x=numerical, kde=True)
    
    # Setting title, labels
    plt.title('Distribution of {}'.format(numerical), size=14)
    plt.xlabel(numerical, size=12)
    plt.ylabel('Count', size=12)
    
plt.show()

The age of patients seems to be pretty evenly distributed between 0 and 80, which is nice. The avg_glucose_level seems to be bimodal, with a second, smaller peak in the 200s. The bmi distribution is unimodal and skewed to the right. Now let's look at the same distributions but distinguish between the patients who had a stroke and the ones who didn't. This allows us to see the distributions of the features for patients who had a stroke in comparison to the overall distributions.

In [11]:
plt.figure(figsize=(20, 6))

for i, numerical in enumerate(numericals):
    ax = plt.subplot(1, 3, i+1)
    sns.histplot(data=data, x=numerical, hue='stroke', kde=True)
    
    # Setting title, labels
    plt.title('Distribution of {} distinguished by stroke'.format(numerical), size=14)
    plt.xlabel(numerical, size=12)
    plt.ylabel('Count', size=12)
    
plt.show()

We find interesting patterns in our newly plotted distributions, especially in age, where the frequency of having a stroke roughly goes up with age. The bmi of stroke patients seems to follow the shape of the overall distribution, and so does the avg_glucose_level to a small degree. Let us take a closer look at the separate distributions distinguished by stroke.

In [12]:
plt.figure(figsize=(20, 7))

for i, numerical in enumerate(numericals):
    ax = plt.subplot(1, 3, i+1)
    sns.violinplot(data=data, x='stroke', y=numerical)
    
    # Setting title, labels
    plt.title('Distribution of {} distinguished by stroke'.format(numerical), size=14)
    plt.xlabel('Stroke', size=12)
    plt.ylabel(numerical, size=12)
    
plt.show()

The biggest takeaway here is that age is a pretty big factor in whether a patient had a stroke or not: almost all patients who had a stroke were above the age of 40, even though age is fairly evenly distributed in the dataset. We will test this hypothesis along with the other features later in this tutorial. In the meantime, we can look at a correlation matrix of the numeric columns. This measures the linear relationship between variables, and although we will be using a logistic rather than a linear regression later, it is still worth a look.

In [13]:
plt.figure(figsize=(9, 6))

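# Correlations are only defined for numeric columns, so select those first
# (recent pandas versions raise an error if text columns are included)
corr_matrix = data.select_dtypes(include='number').corr()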
# A heatmap of the correlation matrix. The darker the color the stronger the linear relationship.
sns.heatmap(corr_matrix, cmap='Blues')

plt.show()
In [14]:
corr_matrix['stroke'].sort_values(ascending=False)
Out[14]:
stroke               1.000000
age                  0.245257
heart_disease        0.134914
avg_glucose_level    0.131945
hypertension         0.127904
bmi                  0.036110
Name: stroke, dtype: float64
  • As we guessed, age has the strongest linear correlation with a patient having a stroke or not. We will test the logistic relationships later.

Categorical Data

Now for the categorical data. All the features which are not numerical are part of the categorical data. Since we cannot plot the distribution of these features like we did for the numerical ones, we will instead use a countplot. A countplot simply plots the count of each category as a bar plot, which allows us to view the distribution of the categories. I added the percentages to make it easier to read the plots and the relative frequency of categories. We also want to check the number of categories each feature has in case one of them has a very large number, which would make the countplots look very messy.

In [15]:
categoricals = ['gender', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'smoking_status', 'stroke']
In [16]:
# printing number of categories in each feature
for categorical in categoricals:
    print(categorical, ':', len(data[categorical].value_counts()))
gender : 3
hypertension : 2
heart_disease : 2
ever_married : 2
work_type : 5
Residence_type : 2
smoking_status : 4
stroke : 2
  • Each feature has a modest number of categories so we should be ready to plot them.
In [17]:
plt.figure(figsize=(12, 24))

for i, categorical in enumerate(categoricals):
    ax = plt.subplot(4, 2, i+1)
    sns.countplot(data=data, x=categorical)
    
    # Setting title, labels
    plt.title('Frequency of {}'.format(categorical), size=14)
    plt.xlabel('', size=12)
    plt.ylabel('Count', size=12)
    
    # Showing the percentages
    for p in ax.patches:
        x=p.get_bbox().get_points()[:,0]
        y=p.get_bbox().get_points()[1,1]
        ax.annotate('{:.1f}%'.format(100.*y/len(data)), (x.mean(), y), 
                    ha='center', va='bottom')
    
plt.show()

Some observations we can take from these plots are that there are considerably more females than males and that Residence_type is the only feature whose categories are roughly evenly distributed. Now we will look at the counts of these features split by whether the patient had a stroke or not, similar to what we did earlier.

In [18]:
plt.figure(figsize=(12, 20))

for i, categorical in enumerate(categoricals):
    ax = plt.subplot(4, 2, i+1)
    sns.countplot(data=data, x=categorical, hue='stroke')
    
    # Setting title, labels
    plt.title('Frequency of {} split by stroke'.format(categorical), size=14)
    plt.xlabel('', size=12)
    plt.ylabel('Count', size=12)
    
    # Showing the percentages
    for p in ax.patches:
        x=p.get_bbox().get_points()[:,0]
        y=p.get_bbox().get_points()[1,1]
        ax.annotate('{:.1f}%'.format(100.*y/len(data)), (x.mean(), y), 
                    ha='center', va='bottom')
    
plt.show()

There are a lot of interesting takeaways we can get from these plots, but I will focus on the one I found most interesting, which is work_type. Here are the statistics: roughly 8% of self-employed workers had a stroke, followed by 5% of private workers, then 4% of government workers. Although I am not a health expert, what I got from this is that stress may play a role in whether a person had a stroke or not. Typically self-employed workers have a lot of responsibilities and consequently more stress, while government jobs in general are less demanding than jobs in private companies. Of course this is just speculation, but it was interesting to find these percentages.
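
Note that the percentages drawn on the bars are relative to the whole dataset, so the stroke rate within each category has to be computed separately. Since stroke is a 0/1 column, its mean within a group is the share of patients in that group who had a stroke. A minimal sketch on made-up rows (the toy values are assumptions and do not reproduce the percentages quoted above):

```python
import pandas as pd

# Made-up rows for illustration only
toy = pd.DataFrame({
    "work_type": ["Private", "Private", "Private", "Private",
                  "Self-employed", "Self-employed",
                  "Govt_job", "Govt_job"],
    "stroke":    [0, 0, 0, 1, 1, 0, 0, 0],
})

# Mean of a 0/1 column is the share of 1s, i.e. the stroke rate per group
stroke_rate = (
    toy.groupby("work_type")["stroke"]
       .mean()
       .sort_values(ascending=False)
)

print((100 * stroke_rate).round(1))
```

Applied to the real data, the same groupby call on data would give the within-category rates for work_type or any other categorical feature.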

Hypothesis Testing

We will use hypothesis testing to calculate the p-values of our numerical features and test whether each one has a statistically significant effect on the output. We will be using a logistic model since we are using continuous data to predict a binary outcome (1/0).

In [19]:
import statsmodels.api as sm

# Using statsmodels for hypothesis testing
logit_model = sm.Logit(data.stroke, data[numericals])
logit_result = logit_model.fit()
logit_result.summary()
Optimization terminated successfully.
         Current function value: 0.190841
         Iterations 8
Out[19]:
Logit Regression Results
Dep. Variable: stroke No. Observations: 5110
Model: Logit Df Residuals: 5107
Method: MLE Df Model: 2
Date: Wed, 19 May 2021 Pseudo R-squ.: 0.02008
Time: 23:32:25 Log-Likelihood: -975.20
converged: True LL-Null: -995.19
Covariance Type: nonrobust LLR p-value: 2.085e-09
coef std err z P>|z| [0.025 0.975]
age 0.0425 0.003 14.339 0.000 0.037 0.048
avg_glucose_level 0.0012 0.001 0.963 0.335 -0.001 0.004
bmi -0.1873 0.008 -22.886 0.000 -0.203 -0.171
In [20]:
# Getting p-values from the table
p_values = logit_result.summary2().tables[1]['P>|z|']
p_values.round(3)
Out[20]:
age                  0.000
avg_glucose_level    0.335
bmi                  0.000
Name: P>|z|, dtype: float64

At a 0.01 significance level, not all the parameters in the model are significantly different from 0. The P>|z| column already reports two-sided p-values, so we compare each one directly with alpha = 0.01; there is no need to halve alpha. The p-value of the coefficient of avg_glucose_level is 0.335, which is greater than 0.01, so it is not significant. I chose a 0.01 significance level because we have a lot of data and wanted to be very confident (99% confidence level) when conducting the significance tests of our model's parameters. The age and bmi coefficients, however, are statistically significant in our logistic model. We guessed age would be significant but it was interesting to see the results of the other two features. One caveat: sm.Logit does not add an intercept on its own, so this model was fitted without a constant term (note Df Model: 2), and the coefficients, including the negative sign on bmi, should be read with that in mind.
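
The P>|z| values in the summary can be reproduced from the z column: a two-sided p-value is the probability that a standard normal variable is at least as extreme as |z|. The following is a minimal sketch using only the standard library; the helper name two_sided_p_value is made up for this illustration, and the z values are copied from the table above:

```python
import math

def two_sided_p_value(z):
    """Two-sided p-value of a z statistic under the standard normal."""
    # P(|Z| > |z|) = erfc(|z| / sqrt(2))
    return math.erfc(abs(z) / math.sqrt(2))

alpha = 0.01

# z statistics copied from the regression summary
z_scores = {"age": 14.339, "avg_glucose_level": 0.963, "bmi": -22.886}

p_values_by_hand = {name: two_sided_p_value(z) for name, z in z_scores.items()}

for name, p in p_values_by_hand.items():
    print(f"{name}: p = {p:.3f}, significant at {alpha}: {p < alpha}")
```

Because these p-values are already two-sided, each one is compared with alpha itself.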

Machine Learning

Now we have reached the last step. Machine learning is the process of building systems that learn from the data we feed them. Generally, the more data we give them the better. Our first step is to change the categorical features into numerical ones so we can feed them into the model. We will use one-hot encoding, which creates one column per category and places a 1 in the column of the category each instance falls under and a 0 everywhere else. However, we have to avoid falling into the dummy variable trap: if we keep a column for every category, any one of them can be derived from the others (it is 1 exactly when all the others are 0). This creates multicollinearity, which means one feature is a linear combination of other features, and that has a negative effect on the model. To avoid this, we drop the first category in each feature and let all 0's represent that category.

In [21]:
data_with_dummies = pd.get_dummies(data, columns=categoricals[:-1], drop_first=True)
data_with_dummies.head()
Out[21]:
age avg_glucose_level bmi stroke gender_Male gender_Other hypertension_1 heart_disease_1 ever_married_Yes work_type_Never_worked work_type_Private work_type_Self-employed work_type_children Residence_type_Urban smoking_status_formerly smoked smoking_status_never smoked smoking_status_smokes
0 67.0 228.69 36.6 1 1 0 0 1 1 0 1 0 0 1 1 0 0
1 61.0 202.21 28.1 1 0 0 0 0 1 0 0 1 0 0 0 1 0
2 80.0 105.92 32.5 1 1 0 0 1 1 0 1 0 0 0 0 1 0
3 49.0 171.23 34.4 1 0 0 0 0 1 0 1 0 0 1 0 0 1
4 79.0 174.12 24.0 1 0 0 1 0 1 0 0 1 0 0 0 1 0
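
To see what drop_first does on a small scale, here is a sketch with three made-up rows. The dtype=int argument is an addition for illustration: it keeps the dummy columns as 0/1 integers, since newer pandas versions return booleans by default.

```python
import pandas as pd

# Three made-up patients, one per smoking category
toy = pd.DataFrame({
    "age": [67.0, 61.0, 49.0],
    "smoking_status": ["formerly smoked", "never smoked", "smokes"],
})

# One column per category, minus the first category in sorted order
encoded = pd.get_dummies(
    toy, columns=["smoking_status"], drop_first=True, dtype=int
)

print(encoded)
```

The first category in sorted order ("formerly smoked" in this toy table) gets no column of its own and is represented by zeros in both remaining columns. In the real data, "Unknown" sorts first, which is why there is no smoking_status_Unknown column in the output above.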

Our next step is to create train and test sets. We feed the train set to our model so it can learn from it, and then measure its accuracy on the test set. If we use all our data to train the model, we won't have any data to test our model with! A common size for the test set is 20% of our total data, leaving 80% for the model to learn from.

In [22]:
from sklearn.model_selection import train_test_split

# Purely randomized way of creating train and test sets
train_set, test_set = train_test_split(data_with_dummies, test_size=0.2, random_state=3)
len(train_set), len(test_set)
Out[22]:
(4088, 1022)
In [23]:
train_set.head()
Out[23]:
age avg_glucose_level bmi stroke gender_Male gender_Other hypertension_1 heart_disease_1 ever_married_Yes work_type_Never_worked work_type_Private work_type_Self-employed work_type_children Residence_type_Urban smoking_status_formerly smoked smoking_status_never smoked smoking_status_smokes
2804 69.0 70.98 30.0 0 0 0 0 0 1 0 0 0 0 0 0 0 0
1054 76.0 77.52 40.9 0 0 0 0 0 0 0 0 1 0 1 1 0 0
3549 80.0 125.89 28.9 0 0 0 1 0 1 0 1 0 0 1 0 0 1
1678 79.0 80.57 23.8 0 0 0 0 0 1 0 0 1 0 1 0 1 0
2209 78.0 75.19 27.6 0 1 0 1 0 1 0 0 1 0 0 0 1 0
In [24]:
test_set.head()
Out[24]:
age avg_glucose_level bmi stroke gender_Male gender_Other hypertension_1 heart_disease_1 ever_married_Yes work_type_Never_worked work_type_Private work_type_Self-employed work_type_children Residence_type_Urban smoking_status_formerly smoked smoking_status_never smoked smoking_status_smokes
2778 29.0 116.98 23.4 0 0 0 0 0 1 0 1 0 0 1 0 1 0
4029 58.0 101.96 34.5 0 1 0 0 0 1 0 0 0 0 1 0 1 0
261 37.0 162.96 39.4 0 0 0 0 0 1 0 1 0 0 0 0 1 0
1868 49.0 90.58 23.2 0 0 0 0 0 1 0 0 0 0 1 0 0 0
1028 57.0 78.46 32.6 0 0 0 0 0 1 0 1 0 0 1 0 1 0

As we said earlier, we will be using a Logistic Regression model:

In [25]:
from sklearn.linear_model import LogisticRegression

train_X = train_set.drop(columns=['stroke'])
train_y = train_set.stroke

log_model = LogisticRegression(random_state=3, max_iter=500)
log_model.fit(train_X, train_y)

train_predictions = log_model.predict(train_X)
log_model.score(train_X, train_y)
Out[25]:
0.950587084148728
  • Our model reached about 95% accuracy on the train set. Now let's test it on the test set:
In [26]:
test_X = test_set.drop(columns=['stroke'])
test_y = test_set.stroke

test_predictions = log_model.predict(test_X)
log_model.score(test_X, test_y)
Out[26]:
0.9559686888454012
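  • We got a 95.6% accuracy on the test set, which is close to the train accuracy, so there is no sign of overfitting (a model fitting the train set in so much detail that it does badly on data it has not seen). Keep in mind, though, that about 95.1% of the patients in the dataset did not have a stroke, so a model that always predicts 0 would reach a similar accuracy.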
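
Accuracy alone can be misleading when one class is rare. As a rough illustration with made-up labels (5 strokes in 100 patients, close to the 4.9% stroke rate in the describe() output), the sketch below scores a "model" that always predicts no stroke and also computes its recall, the share of actual stroke patients it catches:

```python
import numpy as np

# Made-up labels: 5 strokes among 100 patients
y_true = np.array([1] * 5 + [0] * 95)

# A "model" that always predicts the majority class (no stroke)
y_pred = np.zeros_like(y_true)

# Accuracy: share of predictions that match the labels
accuracy = (y_true == y_pred).mean()

# Recall: share of actual stroke patients that the model flags
true_positives = np.sum((y_true == 1) & (y_pred == 1))
recall = true_positives / np.sum(y_true == 1)

print(f"accuracy = {accuracy:.2f}, recall = {recall:.2f}")
```

This baseline gets 95% accuracy while finding none of the stroke patients, so for a goal like early treatment it is worth looking at recall (or a confusion matrix) alongside accuracy.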

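Now let's get the residuals, meaning the differences between the actual labels and the predictions, for the whole dataset and create a countplot of them. Since residual = stroke - prediction, a residual of 0 is a correct prediction, 1 is a stroke the model missed (false negative), and -1 is a stroke predicted for a patient who did not have one (false positive).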

In [27]:
all_X = data_with_dummies.drop(columns=['stroke'])
all_y = data_with_dummies.stroke

all_predictions = log_model.predict(all_X)
log_model.score(all_X, all_y)
Out[27]:
0.9516634050880626
In [28]:
# Column for output using logistic model
data_with_dummies['prediction'] = pd.Series(all_predictions.flatten())

# Column for the corresponding residuals
data_with_dummies['residual'] = data_with_dummies['stroke'] - data_with_dummies['prediction']
data_with_dummies.head()
Out[28]:
age avg_glucose_level bmi stroke gender_Male gender_Other hypertension_1 heart_disease_1 ever_married_Yes work_type_Never_worked work_type_Private work_type_Self-employed work_type_children Residence_type_Urban smoking_status_formerly smoked smoking_status_never smoked smoking_status_smokes prediction residual
0 67.0 228.69 36.6 1 1 0 0 1 1 0 1 0 0 1 1 0 0 0 1
1 61.0 202.21 28.1 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1
2 80.0 105.92 32.5 1 1 0 0 1 1 0 1 0 0 0 0 1 0 0 1
3 49.0 171.23 34.4 1 0 0 0 0 1 0 1 0 0 1 0 0 1 0 1
4 79.0 174.12 24.0 1 0 0 1 0 1 0 0 1 0 0 0 1 0 0 1
In [29]:
plt.figure(figsize=(12, 10))
ax = plt.subplot(111)
sns.countplot(data=data_with_dummies, x='residual')

# Showing the percentages
for p in ax.patches:
    x=p.get_bbox().get_points()[:,0]
    y=p.get_bbox().get_points()[1,1]
    ax.annotate('{:.1f}%'.format(100.*y/len(data)), (x.mean(), y), 
                ha='center', va='bottom')

plt.show()

Improving The Model

If the dataset is very large (especially relative to the number of attributes), it is generally fine to use randomized sampling. However, there are many times where we want the sample data to be representative of the population to avoid sampling bias. A sampling method that achieves this is stratified sampling:

  • stratified sampling: the population is divided into homogeneous subgroups called strata, and the right number of instances are sampled from each stratum to guarantee that the test set is representative of the overall population.

We will divide the samples based on the proportion of strokes, so we can end up with representative train and test sets for patients who had a stroke and did not have a stroke.

In [30]:
from sklearn.model_selection import StratifiedShuffleSplit

# Creating train and test sets using stratified sampling
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=3) # object we can use to split our data

# This loop will run 'n_splits' times because that's how many splits we want
for train_index, test_index in splitter.split(data_with_dummies, data_with_dummies['stroke']):
    # split() yields positional indices, so .iloc is the matching indexer
    strat_train_set = data_with_dummies.iloc[train_index]
    strat_test_set = data_with_dummies.iloc[test_index]
In [31]:
# Proportions of stroke in the test set
strat_test_set['stroke'].value_counts() / len(strat_test_set)
Out[31]:
0    0.951076
1    0.048924
Name: stroke, dtype: float64
In [32]:
# Proportions of stroke in the overall dataset
data_with_dummies['stroke'].value_counts() / len(data_with_dummies)
Out[32]:
0    0.951272
1    0.048728
Name: stroke, dtype: float64
  • As we can see from above, the proportions of each stratum in the test set and overall data are almost identical. This tells us that we have a well-represented test set which is exactly what we want.
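
For a single split, train_test_split also accepts a stratify argument, which is a shorter way to get the same kind of split. A minimal sketch on made-up data with a 5% positive class (the toy frame and its column values are assumptions for illustration):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Made-up data: 1000 rows, 5% positive class
rng = np.random.default_rng(3)
toy = pd.DataFrame({
    "age": rng.uniform(0, 82, size=1000),
    "stroke": [1] * 50 + [0] * 950,
})

# stratify keeps the class proportions the same in both parts
train_part, test_part = train_test_split(
    toy, test_size=0.2, random_state=3, stratify=toy["stroke"]
)

print(test_part["stroke"].mean(), toy["stroke"].mean())
```

Both printed proportions should match, because the 50 positive rows are divided between the two parts in the same 80/20 ratio as the rest of the data.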

Now let's train and test a model using the new train and test sets and observe the accuracy.

In [33]:
strat_train_X = strat_train_set.drop(columns=['stroke'])
strat_train_y = strat_train_set.stroke

strat_log_model = LogisticRegression(random_state=3, max_iter=500)
strat_log_model.fit(strat_train_X, strat_train_y)

strat_train_predictions = strat_log_model.predict(strat_train_X)
strat_log_model.score(strat_train_X, strat_train_y)
Out[33]:
0.9995107632093934
  • Our new model scores 99.95% on the train set, a surprisingly large jump from 95%. Now let's test it on the test set:
In [34]:
strat_test_X = strat_test_set.drop(columns=['stroke'])
strat_test_y = strat_test_set.stroke

strat_test_predictions = strat_log_model.predict(strat_test_X)
strat_log_model.score(strat_test_X, strat_test_y)
Out[34]:
0.9990215264187867
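  • We got a 99.9% accuracy on the test set. The train and test scores agree, but this figure should not be trusted: data_with_dummies still contains the prediction and residual columns added in In [28], and since residual = stroke - prediction, the label can be recovered exactly from these two features. The jump in accuracy comes from this target leakage, not from stratified sampling.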
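
One thing worth checking before trusting a very high score is whether any feature column was derived from the label. In this notebook, data_with_dummies gained prediction and residual columns in In [28], with residual defined as stroke - prediction. The sketch below uses made-up rows to show the check and the fix; the names derived_cols and clean are illustrative only:

```python
import pandas as pd

# Made-up frame mirroring the columns created in In [28]
toy = pd.DataFrame({
    "age": [67.0, 61.0, 30.0, 45.0],
    "stroke": [1, 1, 0, 0],
    "prediction": [0, 1, 0, 1],
})
toy["residual"] = toy["stroke"] - toy["prediction"]

# The label is an exact function of two "features": that is target leakage
leaks = bool((toy["prediction"] + toy["residual"] == toy["stroke"]).all())

# Drop every column derived from the label or from an earlier model
# before creating the train and test sets
derived_cols = ["prediction", "residual"]
clean = toy.drop(columns=derived_cols)

print(leaks, list(clean.columns))
```

Running the stratified split on a frame cleaned this way would give a fair comparison with the first model; the resulting accuracy is not shown here because it was not part of the original run.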

Now let's get the new residuals for the whole dataset and create a countplot for it.

In [35]:
strat_all_X = data_with_dummies.drop(columns=['stroke'])
strat_all_y = data_with_dummies.stroke

strat_all_predictions = strat_log_model.predict(strat_all_X)
strat_log_model.score(strat_all_X, strat_all_y)
Out[35]:
0.999412915851272
In [36]:
# Column for output using logistic model with stratified sampling
data_with_dummies['strat_prediction'] = pd.Series(strat_all_predictions.flatten())

# Column for the new corresponding residuals
data_with_dummies['new_residual'] = data_with_dummies['stroke'] - data_with_dummies['strat_prediction']
data_with_dummies.head()
Out[36]:
age avg_glucose_level bmi stroke gender_Male gender_Other hypertension_1 heart_disease_1 ever_married_Yes work_type_Never_worked ... work_type_Self-employed work_type_children Residence_type_Urban smoking_status_formerly smoked smoking_status_never smoked smoking_status_smokes prediction residual strat_prediction new_residual
0 67.0 228.69 36.6 1 1 0 0 1 1 0 ... 0 0 1 1 0 0 0 1 1 0
1 61.0 202.21 28.1 1 0 0 0 0 1 0 ... 1 0 0 0 1 0 0 1 1 0
2 80.0 105.92 32.5 1 1 0 0 1 1 0 ... 0 0 0 0 1 0 0 1 1 0
3 49.0 171.23 34.4 1 0 0 0 0 1 0 ... 0 0 1 0 0 1 0 1 1 0
4 79.0 174.12 24.0 1 0 0 1 0 1 0 ... 1 0 0 0 1 0 0 1 1 0

5 rows × 21 columns

In [37]:
plt.figure(figsize=(12, 10))
ax = plt.subplot(111)
sns.countplot(data=data_with_dummies, x='new_residual')

# Showing the percentages
for p in ax.patches:
    x=p.get_bbox().get_points()[:,0]
    y=p.get_bbox().get_points()[1,1]
    ax.annotate('{:.1f}%'.format(100.*y/len(data)), (x.mean(), y), 
                ha='center', va='bottom')

plt.show()

Conclusion

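In this tutorial, we learned how to get data, analyze it, and perform hypothesis testing and machine learning on it. We gained many insights and observed many interesting patterns, and at the end, we built a logistic regression model to predict strokes. Its accuracy of about 95% matches the share of patients without a stroke, and the 99.9% reported for the stratified run was obtained with the prediction and residual columns among the features, so it reflects leaked labels rather than real predictive skill.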