Keras callbacks guide and code

I decided to look into Keras callbacks.

When you want to run some task at the start or end of training, of every epoch, or of every batch, that's when you need to define your own callback. It's simple; I just had to read the source to learn what I could do with it. In my case, I wanted to compute an auc_roc score after every training epoch. I had been computing it outside the fit function, since I train from multiple HDF5 files. Let's fix that.

Your own callback can be defined as below:

import keras

class My_Callback(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        # initialize anything you want to accumulate here
        self.losses = []
        return

    def on_train_end(self, logs={}):
        return

    def on_epoch_begin(self, epoch, logs={}):
        return

    def on_epoch_end(self, epoch, logs={}):
        return

    def on_batch_begin(self, batch, logs={}):
        return

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
        return


The code is quite straightforward. This class inherits from keras.callbacks.Callback, which already defines those on_{train, epoch, batch}_{begin, end} methods. What we need to do is override them, and then pass an instance of our callback to Keras's model.fit function via its callbacks argument.
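To make the call order concrete, here is a rough pure-Python sketch of what fit does with these hooks. This is my own simplification for illustration, not Keras's actual implementation:

```python
class RecordingCallback(object):
    """Records which hooks fire, in order.
    In real code this would subclass keras.callbacks.Callback."""
    def __init__(self):
        self.calls = []
    def on_train_begin(self, logs={}):
        self.calls.append('train_begin')
    def on_train_end(self, logs={}):
        self.calls.append('train_end')
    def on_epoch_begin(self, epoch, logs={}):
        self.calls.append('epoch_begin_%d' % epoch)
    def on_epoch_end(self, epoch, logs={}):
        self.calls.append('epoch_end_%d' % epoch)
    def on_batch_begin(self, batch, logs={}):
        self.calls.append('batch_begin_%d' % batch)
    def on_batch_end(self, batch, logs={}):
        self.calls.append('batch_end_%d' % batch)

def sketch_fit(callback, nb_epoch=1, nb_batch=2):
    """Roughly what model.fit() does with the callback hooks."""
    callback.on_train_begin()
    for epoch in range(nb_epoch):
        callback.on_epoch_begin(epoch)
        for batch in range(nb_batch):
            callback.on_batch_begin(batch, logs={'batch': batch, 'size': 128})
            # ... train on one batch here ...
            callback.on_batch_end(batch, logs={'batch': batch, 'size': 128, 'loss': 2.3})
        callback.on_epoch_end(epoch, logs={'loss': 2.3})
    callback.on_train_end()

cb = RecordingCallback()
sketch_fit(cb)
# cb.calls: ['train_begin', 'epoch_begin_0', 'batch_begin_0', 'batch_end_0',
#            'batch_begin_1', 'batch_end_1', 'epoch_end_0', 'train_end']
```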

What can we do in each function? First, there are the input arguments – epoch/batch and logs. What are they? And which properties do we need?

on_train_begin(self, logs={})

See an example on Keras’ documentation.

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

Yes, this is where we define useful properties. In this example, self.losses is added and initialized as an empty list.

logs

There is not much to the input argument here. I think logs is just an empty dictionary in most cases.


(pdb) print logs
{}

This is because nothing has happened yet. I'm not sure what happens if we save a model and load it again, though.

DETOUR! Let's look into self.params and self.model. This isn't specific to on_train_begin(); it applies to every callback function.

self.params


(Pdb) self.params
{'verbose': 1, 'nb_epoch': 12, 'batch_size': 128, 'metrics': ['loss', 'acc', 'val_loss', 'val_acc'], 'nb_sample': 60000, 'do_validation': True}

self.params holds some useful information about the training configuration.

self.model


(Pdb) self.model
<keras.models.Sequential object at 0x1063b8e90>

Oops, seems like we can do something interesting with this in every callback. This is the model instance – the Sequential (or functional Model) we are training – and we can pull a lot of information from it.

Check out the model page of the documentation – there are a lot of properties and functions! A tentative list would be…

self.model.validation_data

(This property does not exist yet in on_train_begin; it appears by on_batch_begin.)

len(self.model.validation_data) == 3, because validation_data[0] is the validation x (what you pass as validation data to model.fit()), validation_data[1] is the validation y, and validation_data[2] is the sample weights, as below.


(Pdb) self.model.validation_data[0].shape
(10000, 1, 28, 28)

(Pdb) self.model.validation_data[1].shape
(10000, 10)

(Pdb) self.model.validation_data[2].shape
(10000,)

You can use this data to compute your own metric, e.g. roc_auc_score from the scikit-learn package.

self.model.training_data

It shows up in dir(self.model), but I couldn't access it at any point during training.

self.model.save_weights

You may want to save weights during training.
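For instance, a callback might checkpoint after every epoch via self.model. The sketch below is my own; the file-name pattern is arbitrary, and Keras's built-in ModelCheckpoint callback already covers this use case. It is written as a plain class (with a stub model) so it runs stand-alone; in real code it would subclass keras.callbacks.Callback:

```python
class CheckpointEachEpoch(object):
    """Sketch: save weights at the end of every epoch.
    In real code this would subclass keras.callbacks.Callback."""
    def on_epoch_end(self, epoch, logs={}):
        # self.model is attached by Keras before training starts
        self.model.save_weights('weights_epoch_%02d.h5' % epoch)

# Tiny stub standing in for the Keras model, so the sketch runs
# stand-alone; it just records the paths instead of writing files.
class _StubModel(object):
    def __init__(self):
        self.saved = []
    def save_weights(self, path):
        self.saved.append(path)

cb = CheckpointEachEpoch()
cb.model = _StubModel()
cb.on_epoch_end(3)
# cb.model.saved == ['weights_epoch_03.h5']
```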

on_epoch_begin(self, epoch, logs={})

When the very first epoch begins,


(Pdb) epoch
0
(Pdb) logs
{}

When the second epoch begins,


(Pdb) epoch
1
(Pdb) logs
{}

I.e., logs is cleared every time, and epoch is the epoch index, zero-based.

on_epoch_end(self, epoch, logs={})

When the very first epoch ends,


(Pdb) epoch
0
(Pdb) logs
{'acc': 0.13145000000000001, 'loss': 2.3134536211649577, 'val_acc': 0.16389999999999999, 'val_loss': 2.28033113861084}

When the second epoch ends,


(Pdb) epoch
1
(Pdb) logs
{'acc': 0.15653333332538605, 'loss': 2.255207451756795, 'val_acc': 0.185, 'val_loss': 2.2099738941192628}

So logs only contains the results of the current epoch. That's why you need to accumulate the results yourself, as below:


def on_epoch_end(self, epoch, logs={}):
    self.losses.append(logs.get('loss'))
    return

on_batch_begin(self, batch, logs={})


(Pdb) batch
0
(Pdb) logs
{'batch': 0, 'size': 128}

batch is, again, the batch index, and logs carries some information about the batch.

on_batch_end(self, batch, logs={})


(Pdb) batch
0
(Pdb) logs
{'acc': array(0.1015625, dtype=float32), 'loss': array(2.366058349609375, dtype=float32), 'batch': 0, 'size': 128}

After training on a batch, logs has a bit more information.

Example

I'd like to compute roc_auc_score for every epoch, as mentioned, and store the results somewhere. Here it is; I override all the functions to make the structure clear (redundant though).


import keras
from sklearn.metrics import roc_auc_score

class Histories(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.aucs = []
        self.losses = []

    def on_train_end(self, logs={}):
        return

    def on_epoch_begin(self, epoch, logs={}):
        return

    def on_epoch_end(self, epoch, logs={}):
        self.losses.append(logs.get('loss'))
        y_pred = self.model.predict(self.model.validation_data[0])
        self.aucs.append(roc_auc_score(self.model.validation_data[1], y_pred))
        return

    def on_batch_begin(self, batch, logs={}):
        return

    def on_batch_end(self, batch, logs={}):
        return

The full code is in this repo. Clone it and run mnist_cnn.py.
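If you want to sanity-check the bookkeeping in such a callback without a full training run, you can drive the hooks by hand against a stub. Everything below is my own scaffolding: the stub model, and a rank-based AUC helper inlined so the snippet needs no scikit-learn (the real callback uses roc_auc_score and a live Keras model):

```python
def rank_auc(y_true, y_score):
    """ROC AUC for binary labels via the Mann-Whitney rank statistic
    (assumes no tied scores)."""
    order = sorted(range(len(y_score)), key=lambda i: y_score[i])
    rank_sum = sum(r + 1 for r, i in enumerate(order) if y_true[i] == 1)
    n_pos = sum(y_true)
    n_neg = len(y_true) - n_pos
    return (rank_sum - n_pos * (n_pos + 1) / 2.0) / float(n_pos * n_neg)

class _StubModel(object):
    """Mimics the bits of the model the callback reads:
    validation_data == [x, y, sample_weight], plus predict()."""
    def __init__(self, val_x, val_y):
        self.validation_data = [val_x, val_y, [1.0] * len(val_x)]
    def predict(self, x):
        return x  # identity "scores", just for the smoke test

class Histories(object):  # in real code: keras.callbacks.Callback
    def on_train_begin(self, logs={}):
        self.aucs = []
        self.losses = []
    def on_epoch_end(self, epoch, logs={}):
        self.losses.append(logs.get('loss'))
        y_pred = self.model.predict(self.model.validation_data[0])
        self.aucs.append(rank_auc(self.model.validation_data[1], y_pred))

cb = Histories()
cb.model = _StubModel([0.1, 0.2, 0.8, 0.9], [0, 0, 1, 1])
cb.on_train_begin()
cb.on_epoch_end(0, logs={'loss': 2.31})
# cb.losses == [2.31]; the scores separate the classes perfectly,
# so cb.aucs == [1.0]
```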
