Building a multiclass classification model
Data cleaning, adding structures to PubChem data, building a multiclass model, dealing with imbalanced data
- Introduction
- Read and clean the PubChem data
- Add chemical structures to the PubChem data
- Calculate molecular descriptors
- Split the data into training and test sets
- Create and evaluate a machine learning model
- Use oversampling to compensate for imbalanced data
- Comparing the standard and oversampled models
Introduction
At the Fall 2021 ACS Meeting, the group from NCATS described a number of ADME models they had developed. Even better, the NCATS team also released some of the data used to build these models. In this notebook we'll use the data from the NCATS CYP3A4 assay to classify molecules as CYP3A4 activators, inhibitors, or inactive.
To run this notebook, the following Python libraries should be installed:
- pandas - handling data tables
- pubchempy - grabbing chemical structures from PubChem
- tqdm - progress bars
- numpy - linear algebra and matrices
- itertools - advanced list handling (part of the Python standard library)
- scikit-learn - machine learning
- lightgbm - gradient boosted trees for machine learning
- matplotlib - plotting
- seaborn - even better plotting
- pingouin - stats
- imbalanced-learn - machine learning with imbalanced datasets
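If any of the third-party packages are missing, they can be installed from PyPI (itertools ships with Python and doesn't need to be installed), e.g. from the command line:
pip install pandas pubchempy tqdm numpy scikit-learn lightgbm matplotlib seaborn pingouin imbalanced-learn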
import pandas as pd
import pubchempy as pcp
from tqdm.auto import tqdm
import numpy as np
import itertools
from lib.descriptor_gen import DescriptorGen
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from lightgbm import LGBMClassifier
from sklearn.metrics import ConfusionMatrixDisplay, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
Enable Pandas progress_apply so that we can get progress bars for Pandas operations
tqdm.pandas()
Read and clean the PubChem data
# skip rows 1-5, which contain assay metadata rather than compound records
df = pd.read_csv("data/AID_1645841_datatable.csv", skiprows=[1, 2, 3, 4, 5], low_memory=False)
df
Note that the data file doesn't include chemical structures as SMILES strings. We're going to add those using the pubchempy library, which can look up a chemical structure based on the PUBCHEM_CID field in our dataframe. This is great, but the lookup will fail if we pass it a null value. Let's check whether we have any null values in the PUBCHEM_CID column.
sum(df.PUBCHEM_CID.isna())
We have four null values. Let's look at those rows.
df[df.PUBCHEM_CID.isna()]
We can drop the four rows where PUBCHEM_CID is null.
df.dropna(subset=["PUBCHEM_CID"],inplace=True)
df
In order to look up a structure by PUBCHEM_CID, the field must be an integer. Let's take a look at the datatypes in our dataframe.
df.dtypes
The PUBCHEM_CID field is a float64, which is not what we want. Let's convert that column to an integer.
df.PUBCHEM_CID = df.PUBCHEM_CID.astype(int)
The field we want to model is Phenotype-Replicate_1, which takes one of three values. Let's look at the possible values and their distribution. We can see that the Activator class is somewhat underrepresented. We'll start by building a model with the data as provided. Once we've done this, we'll also take a look at whether we can improve the model by employing strategies that compensate for the data imbalance.
df['Phenotype-Replicate_1'].value_counts(normalize=True)
In order to build our model, we need the column we're predicting to be represented as numeric values. We can convert the text labels to numbers using the LabelEncoder class from scikit-learn.
labels = df['Phenotype-Replicate_1'].unique().tolist()
le = LabelEncoder()
le.fit(labels)
df['label'] = le.transform(df['Phenotype-Replicate_1'])
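As a quick check (not part of the original analysis), we can confirm which integer was assigned to each phenotype. LabelEncoder orders classes alphabetically, so Activator should map to 0.
dict(zip(le.classes_, le.transform(le.classes_)))  # expected: {'Activator': 0, 'Inactive': 1, 'Inhibitor': 2}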
Our dataframe has a bunch of extra fields that are not necessary for this analysis. Let's simplify and create a new dataframe with only the fields we care about.
data_df = df[['PUBCHEM_CID','Phenotype-Replicate_1','label']].copy()
data_df
Add chemical structures to the PubChem data
Look up the compounds with pubchempy. Rather than making thousands of individual requests, we query the PubChem service in chunks of roughly 100 CIDs at a time.
cmpd_list = []
num_chunks = len(data_df) // 100
for chunk in tqdm(np.array_split(data_df.PUBCHEM_CID, num_chunks)):
    cmpd_list.append(pcp.get_compounds(chunk.tolist()))
We collected the chemical structures as a list of lists. We need to flatten this into a single list. The operation works like this:
[[1,2,3],[4,5,6],[7,8,9]] -> [1,2,3,4,5,6,7,8,9]
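The itertools.chain function does exactly this flattening; here's a quick illustration:
list(itertools.chain(*[[1, 2, 3], [4, 5, 6], [7, 8, 9]]))  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
Applying the same pattern to our list of compound results: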
data_df['Compound'] = list(itertools.chain(*cmpd_list))
Extract the SMILES from the Compound objects in the Compound column.
data_df['SMILES'] = [x.canonical_smiles for x in data_df.Compound]
Calculate molecular descriptors
desc_gen = DescriptorGen()
Add the descriptors to the dataframe.
data_df['desc'] = data_df.SMILES.progress_apply(desc_gen.from_smiles)
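The DescriptorGen class comes from a local module, lib/descriptor_gen.py, which isn't shown in this notebook. If you don't have that file, a minimal stand-in that maps each SMILES to a fixed-length numeric vector would work; the sketch below uses a 2048-bit Morgan fingerprint from RDKit (the real class may compute a different descriptor set).
import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

class DescriptorGen:
    # hypothetical stand-in: one 2048-bit Morgan (ECFP-like) fingerprint per molecule
    def __init__(self, radius=2, n_bits=2048):
        self.radius = radius
        self.n_bits = n_bits

    def from_smiles(self, smi):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # unparseable SMILES -> all-zero vector so np.stack still works downstream
            return np.zeros(self.n_bits, dtype=np.float32)
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, self.radius, nBits=self.n_bits)
        arr = np.zeros(self.n_bits, dtype=np.float32)
        DataStructs.ConvertToNumpyArray(fp, arr)
        return arr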
Split the data into training and test sets
Use scikit-learn's train_test_split (default 75/25 split) and stack the descriptor vectors into 2D numpy arrays for modeling.
train, test = train_test_split(data_df)
train_X = np.stack(train.desc)
train_y = train.label
test_X = np.stack(test.desc)
test_y = test.label
Create and evaluate a machine learning model
- Instantiate a LightGBM classifier
- Train the model
- Predict on the test set
lgbm = LGBMClassifier()
lgbm.fit(train_X, train_y)
pred = lgbm.predict_proba(test_X)
Evaluate model performance using the ROC AUC score. Note that this is a little different with a multiclass classifier. We specify multi_class='ovo', meaning "one vs one": the AUC is evaluated for every pair of classes. The argument average='macro' indicates that the reported AUC is the unweighted average of all of the one vs one comparisons.
roc_auc_score(test_y,pred,multi_class='ovo',average='macro')
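To make the averaging explicit, here is a rough hand-rolled version of the 'ovo'/'macro' calculation (following the Hand and Till approach that scikit-learn uses; a sketch for illustration, not a replacement for roc_auc_score):
def ovo_macro_auc(y_true, proba):
    # average the pairwise AUCs over all class pairs; each pair is scored in
    # both directions and the two directional AUCs are averaged
    y = np.asarray(y_true)
    aucs = []
    for i, j in itertools.combinations(range(proba.shape[1]), 2):
        mask = np.isin(y, [i, j])
        auc_j = roc_auc_score(y[mask] == j, proba[mask, j])
        auc_i = roc_auc_score(y[mask] == i, proba[mask, i])
        aucs.append((auc_j + auc_i) / 2)
    return float(np.mean(aucs))

ovo_macro_auc(test_y, pred)  # should match the roc_auc_score call above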
We can also plot a confusion matrix to examine the model's performance on each of the three classes.
sns.set_style("white")
sns.set_context('talk')
plt.rcParams["figure.figsize"] = (8,8)
ConfusionMatrixDisplay.from_estimator(lgbm, test_X, test_y, display_labels=sorted(labels), cmap=plt.cm.Blues)
Use oversampling to compensate for imbalanced data
from imblearn.over_sampling import RandomOverSampler
We will create an oversampling object and use it to resample our training set. By default, RandomOverSampler randomly duplicates rows from the minority classes until every class has as many examples as the majority class.
ros = RandomOverSampler()
resample_X, resample_y = ros.fit_resample(train_X,train_y)
Recall that our original training set is somewhat imbalanced. The minority class (0 or Activator) only accounts for ~7% of the data.
pd.Series(train_y).value_counts()
After oversampling, the dataset is balanced.
pd.Series(resample_y).value_counts()
Build a model with the balanced, oversampled data
resample_lgbm = LGBMClassifier()
resample_lgbm.fit(resample_X, resample_y)
Make a prediction with the new model, built with the resampled data, and evaluate it with the same one vs one macro-averaged AUC so the two models can be compared directly.
resample_pred = resample_lgbm.predict_proba(test_X)
roc_auc_score(test_y, resample_pred, multi_class='ovo', average='macro')
As above, we can plot a confusion matrix to examine the performance of the classifier trained on the oversampled data. Let's put the two confusion matrices side by side to compare.
sns.set_style("white")
sns.set_context('talk')
classifiers = [lgbm, resample_lgbm]
titles = ["Standard", "Oversampled"]
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
for cls, ax, title in zip(classifiers, axes, titles):
    ConfusionMatrixDisplay.from_estimator(cls, test_X, test_y, display_labels=sorted(labels), cmap=plt.cm.Blues, ax=ax)
    ax.title.set_text(title)
plt.tight_layout()
Comparing the standard and oversampled models
The difference in AUC above could simply reflect one lucky train/test split. To get a fairer comparison, we repeat the entire procedure over 10 random splits, building a standard and an oversampled model on each split and recording both AUC values.
res = []
for i in tqdm(range(0, 10)):
    # split the data into training and test sets
    train, test = train_test_split(data_df)
    train_X = np.stack(train.desc)
    train_y = train.label
    test_X = np.stack(test.desc)
    test_y = test.label
    # create the standard model
    lgbm = LGBMClassifier()
    lgbm.fit(train_X, train_y)
    pred = lgbm.predict_proba(test_X)
    auc = roc_auc_score(test_y, pred, multi_class='ovo', average='macro')
    # create the oversampled model
    resample_lgbm = LGBMClassifier()
    resample_X, resample_y = ros.fit_resample(train_X, train_y)
    resample_lgbm.fit(resample_X, resample_y)
    resample_pred = resample_lgbm.predict_proba(test_X)
    resample_auc = roc_auc_score(test_y, resample_pred, multi_class='ovo', average='macro')
    res.append([auc, resample_auc])
Create a dataframe with the AUC values from the Standard and Oversampled models.
res = np.array(res)
res_df = pd.DataFrame(res,columns=["Standard","Oversampled"])
res_df.head()
Reformat the dataframe from wide to long format, combining the two columns of res_df into a Method column and an AUC column.
melt_df = res_df.melt()
melt_df.columns = ["Method","AUC"]
melt_df
Plot the AUC distributions for the Standard and Oversampled models as a kernel density estimate (KDE)
sns.set(rc={'figure.figsize': (10, 10)})
sns.set_context('talk')
sns.kdeplot(x="AUC",hue="Method",data=melt_df);
Another way of comparing the distributions is to use the plot_paired function from the pingouin library. Note that the AUC for the Oversampled method is always greater than that for the Standard method.
from pingouin import wilcoxon, plot_paired
melt_df['cycle'] = list(range(0, 10)) + list(range(0, 10))  # pairing index so plot_paired knows which values came from the same split
plot_paired(data=melt_df,dv="AUC",within="Method",subject="cycle");
To compare two distributions, we often perform a t-test. However, a t-test assumes that the data is normally distributed. Since we can't make this assumption, we use the Wilcoxon signed-rank test, the non-parametric equivalent of the paired t-test. The pingouin library provides a convenient implementation in the wilcoxon function. As we can see from the p-value in the table below, the difference between the two distributions is statistically significant.
wilcoxon(res_df.Standard,res_df.Oversampled)
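pingouin returns the result as a one-row DataFrame, so the p-value can be pulled out directly if you need it programmatically (column names follow pingouin's output format):
stats = wilcoxon(res_df.Standard, res_df.Oversampled)
p_value = stats["p-val"].iloc[0]
print(f"Wilcoxon signed-rank p-value: {p_value:.3g}")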