Skin diseases prediction in the browser using TF.js

Melanoma is a type of skin cancer that is not the most common, but it causes the majority of skin cancer deaths.

In its early stages, melanoma is successfully treated in most cases, but it has a very poor prognosis when detected in an advanced form.

Noticing melanoma at an early stage is really crucial, so why not build a simple ML app for the browser that classifies moles from a photo?

In 2018, the TensorFlow.js framework was announced for training and executing ML models in the browser and Node.js environments.

There are many pre-trained models available for TF.js, and there is also a converter, so we can convert pre-trained models from other frameworks and run them in the browser. But we’ll try to train our model from scratch.

The first step needed to train a model is to find a good dataset. Fortunately, there is a good dataset on Kaggle, so after downloading it we can start training our model.

The dataset contains a CSV file with the pixels of 28x28 images and the diagnosed disease for each row. That’s exactly what we need to train an ML model.

There are lots of resources nowadays to learn ML. My learning path was:

  1. My first introduction was the book Machine Learning in Action. The book describes basic ML concepts, the tasks that ML solves, and basic algorithms (k-nearest neighbors, decision trees, linear regression, SVM, etc.).
  2. Then I watched the “Deep Learning Fundamentals with Keras” course. It is a pretty short but informative course about how deep neural networks work (with the accent on artificial and convolutional neural networks).
  3. I still felt a knowledge gap, so I also took the “Deep Learning A-Z” course on Udemy. This pretty deep course teaches how to train ANNs, CNNs, and RNNs, build autoencoders, etc. I really liked this course.
  4. Please note, I am still not an AI expert: I practice training small models and continue to read about ML architectures, statistics, etc.

Exploratory analysis, balancing dataset

Let's start with exploratory analysis to see what our data looks like, and then prepare the training and test sets.

For the analysis, we’ll use a Jupyter notebook and Python 3 with the Pandas and NumPy libraries to work with the data. We’ll also use Matplotlib and Seaborn to draw visualizations.

Let’s import our CSV into the Jupyter notebook using pandas.
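
A minimal sketch of the import step; the tiny in-memory CSV below (and its column names) is a hypothetical stand-in, since in the notebook you would point pd.read_csv at the downloaded Kaggle file instead:

```python
import io
import pandas as pd

# A tiny stand-in for the Kaggle file: each row holds the flattened
# 28x28x3 pixel values followed by a "label" column with the disease class.
# For the real dataset you would call pd.read_csv("<path-to-kaggle-csv>").
csv_data = io.StringIO(
    "pixel0000,pixel0001,pixel0002,label\n"
    "142,85,60,4\n"
    "190,153,194,2\n"
)
df = pd.read_csv(csv_data)
print(df.shape)  # (2, 4)
```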

Now we can show diseases distribution:
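
One way to draw that distribution (a sketch with a toy frame standing in for the real data; the Agg backend line just keeps it runnable outside Jupyter):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; not needed inside Jupyter
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in frame; in the notebook `df` comes from the Kaggle CSV.
df = pd.DataFrame({"label": [4, 4, 4, 4, 2, 6, 4, 1]})

# Count images per disease class and plot them as a bar chart.
counts = df["label"].value_counts()
ax = counts.plot(kind="bar", title="Disease distribution")
ax.set_xlabel("disease class")
ax.set_ylabel("images")
plt.savefig("distribution.png")
```

Seaborn's countplot gives an equivalent chart in one call if you prefer it.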

As you can see, we have a very imbalanced dataset: most of the samples are moles, so a model that predicts everything as a mole would already achieve high accuracy.

There are a few strategies for fixing imbalanced datasets, for example:

  1. Oversampling
  2. Undersampling
  3. Generating synthetic data

I used the oversampling strategy, which consists of replicating minority-class samples to balance the dataset. To oversample the dataset we can use the imblearn package:
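
A minimal sketch of what this does, written in pure pandas so it is self-contained; imblearn's RandomOverSampler gives you the same effect out of the box via its fit_resample method (the toy frame below is a hypothetical stand-in for the real data):

```python
import pandas as pd

# Toy imbalanced frame; in the notebook `df` is the full Kaggle data.
df = pd.DataFrame({"pixel0": [1, 2, 3, 4, 5, 6], "label": [4, 4, 4, 4, 2, 6]})

# Replicate minority classes (sampling with replacement) until every
# class matches the majority class size.
max_size = df["label"].value_counts().max()
balanced = pd.concat(
    group.sample(max_size, replace=True, random_state=42)
    for _, group in df.groupby("label")
).reset_index(drop=True)

print(balanced["label"].value_counts().to_dict())  # every class now has 4 rows
```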

Now we can split the data into training and test sets using scikit-learn's train_test_split method.
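
For example (the arrays here are placeholders for the real pixel matrix and labels; stratify=y is my addition to keep class proportions equal in both sets):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical arrays: X holds flattened pixel rows, y the disease labels.
X = np.arange(100 * 3).reshape(100, 3)
y = np.repeat(np.arange(5), 20)

# 80/20 split; stratify=y keeps the class proportions in both sets.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print(X_train.shape, X_test.shape)  # (80, 3) (20, 3)
```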

The dataset looks balanced now, and the data is split into training and test sets, so we can start training our model and predicting results.

Training Model using TF.js

The goal is to build an image classifier, and a simple CNN is the default choice of neural network architecture.

The architecture of the model consists of alternating convolutional/max-pooling layers, followed by a fully connected part with one hidden layer, and an output layer that classifies the 7 diseases provided in the dataset. The activation function for the hidden layers is ReLU, while softmax was chosen for the output layer (it is often used for multi-class classification).

Here is a nice tool to visualize how convolutional/pooling layers work in CNNs.

The first convolutional layer has a 28x28x3 input shape because our training images have a 28x28 resolution and 3 channels (RGB pictures). All convolutional layers have 32 filters and a 3x3 kernel size. The max-pooling layers use a 2x2 pool size.

Categorical cross-entropy was chosen as the loss function, as it often is for multi-class classification. Adam was chosen as the optimizer, and accuracy as the base metric.

To train the model we need to prepare data, build a model, and finally train the model on prepared data. Below is a basic code to build a classifier.

Now, let’s prepare the data for training. We receive the pixel values and the disease label from the CSV and transform each row:

  1. reshape the pixels into a tensor of shape 28x28x3 (remember, we have 28x28 RGB images).
  2. apply one-hot encoding to the labels

The last step here is to group elements into batches using the TF.js batch method.

Now we can launch training for 20 epochs (value chosen empirically), save our model, and evaluate the model on a test set.

After some time we see that our model has 96% accuracy on the training set, and 95% accuracy on the test set. Looks pretty good!

It’s time to write a simple predictor and test our model on concrete images. The algorithm is simple: load the saved model, load and resize the image, normalize its pixels, reshape it into a 28x28x3 tensor, and pass it to the TF.js model.predict method.

Below is a function to load/resize/normalize image pixels:

And here is the code for predicting results:

Let’s check the results:

It looks like the predictor works well, so let’s write an app with React.js and TF.js to predict diseases in the browser using the user’s camera.

Building browser app with React.js

We already have the predictor running in the Node.js environment; now we just need to make it run in the browser.

So let’s create a React app with npx create-react-app skinscan command.

Let’s put the saved model into the “/public/model” folder, then, in the code, set the TF.js backend to WebGL using the tf.setBackend method and load the model using loadLayersModel.

The camera initialization and passing the stream into a video element look roughly like this:
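
For example (a browser-only sketch; the facingMode constraint, which prefers the rear camera on phones, is my assumption):

```javascript
// Request the camera and pipe the stream into a <video> element.
async function initCamera(videoElement) {
  const stream = await navigator.mediaDevices.getUserMedia({
    video: { facingMode: 'environment' }, // prefer the rear camera
    audio: false,
  });
  videoElement.srcObject = stream;
  await videoElement.play();
  return videoElement;
}
```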

Finally, on image capture, let’s predict the result similarly to how we did it in the Node.js environment: crop (I used a 104px centered square as the crop area) and resize the picture, then get the image pixels using the tf.browser.fromPixels method, normalize them, and pass them as an argument to the model.predict method.

Demo Time

A fully-featured app is available here:

I’d suggest trying it on a mobile phone because it’s much easier to take good-quality photos there :).

Feel free to play with it: everything happens in your browser, and I don’t store anything (unless you submit the contact form).

