I have written a few simple Keras layers. This post summarises how to write your own layers. It's for beginners because I only know the simple and easy ones 😉
1. Keras layer
https://keras.io/layers/about-keras-layers/ introduces some common methods. For beginners I don’t think it’s necessary to know these.
2. Keras Lambda layer
A Lambda layer is an easy way to customise a layer to do simple arithmetic. As the documentation page puts it,
…an arbitrary Theano / TensorFlow expression…
we can use the operations supported by the Keras backend, such as dot, transpose, max, pow, sign, etc., as well as those that are not listed in the backend documentation but are actually supported by Theano and TensorFlow, e.g., **, /, //, % for Theano.
2.1 Lambda layer and output_shape
You might need to specify the output shape of your Lambda layer, especially if your Keras is running on Theano. Otherwise it just seems to infer the output shape from input_shape.
2.1.1 With function
You can create a function that returns the output shape, probably after taking input_shape as an input. Here, the function returns the shape of the WHOLE BATCH.
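from keras import backend as K
from keras.layers import Lambda

def output_of_lambda(input_shape):
    return (input_shape[0], 1, input_shape[2])  # shape of the whole batch

def mean(x):
    return K.mean(x, axis=1, keepdims=True)  # average over axis 1, keep it as a length-1 axis

model.add(Lambda(mean, output_shape=output_of_lambda))  # model is an existing Sequential model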
2.1.2 With tuple
If you pass a tuple instead, it should be the shape of ONE DATA SAMPLE, i.e. without the batch axis.
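To make the difference between the two forms concrete, here is a minimal pure-Python sketch. The helper resolve_output_shape and the example shapes are hypothetical and only written for illustration; they mimic the idea of how the two forms end up describing the same full output shape and are not Keras's actual code.

```python
def output_of_lambda(input_shape):
    # Function form: receives and returns the shape of the WHOLE BATCH.
    return (input_shape[0], 1, input_shape[2])

# Tuple form: the shape of ONE DATA SAMPLE, so there is no batch axis.
output_of_one_sample = (1, 40)

def resolve_output_shape(output_shape, input_shape):
    """Hypothetical helper: turn either form into a full batch shape."""
    if callable(output_shape):
        return tuple(output_shape(input_shape))
    # A tuple describes one sample, so the batch axis is prepended.
    return (input_shape[0],) + tuple(output_shape)

# (batch, time, feature); None stands for an unknown batch size.
input_shape = (None, 100, 40)

full_from_function = resolve_output_shape(output_of_lambda, input_shape)
full_from_tuple = resolve_output_shape(output_of_one_sample, input_shape)
print(full_from_function)  # (None, 1, 40)
print(full_from_tuple)     # (None, 1, 40)
```

Both forms describe the same output here. The only thing to remember is whether the batch axis is part of what you write down.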
3. A Keras model as a layer
At a high level, you can combine existing layers to design your own layer. For example, I made a Melspectrogram layer as below. (The complete code is in the keras_STFT_layer repo.) This way, I could re-use the Convolution2D layer in the way I want.
import numpy as np
from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Reshape, Permute, Convolution2D
# get_spectrogram_tensors, _mel and Logam_layer are helpers defined elsewhere (see the keras_STFT_layer repo).

def Melspectrogram(n_dft, input_shape, trainable, n_hop=None,
                   border_mode='same', logamplitude=True, sr=22050,
                   n_mels=128, fmin=0.0, fmax=None, name='melgram'):
    if input_shape is None:
        raise RuntimeError('specify input shape')
    Melgram = Sequential()
    # Prepare STFT.
    x, STFT_magnitude = get_spectrogram_tensors(n_dft,
                                                n_hop=n_hop,
                                                border_mode=border_mode,
                                                input_shape=input_shape,
                                                logamplitude=False)
    # output: (None, freq, time)
    stft_model = Model(input=x, output=STFT_magnitude, name='stft')
    stft_model.trainable = trainable  # set the flag on the model instead of overwriting the model
    Melgram.add(stft_model)
    # Convert to a proper 2D representation (ndim=4)
    if K.image_dim_ordering() == 'th':
        Melgram.add(Reshape((1,) + stft_model.output_shape[1:],
                            name='reshape_to_2d'))  # (None, 1, freq, time)
    else:
        Melgram.add(Reshape(stft_model.output_shape[1:] + (1,),
                            name='reshape_to_2d'))  # (None, freq, time, 1)
    # build a Mel filter
    mel_basis = _mel(sr, n_dft, n_mels, fmin, fmax)  # (128, 1025) (mel_bin, n_freq)
    mel_basis = np.fliplr(mel_basis)  # to make it from low-f to high-freq
    n_freq = mel_basis.shape[1]
    if K.image_dim_ordering() == 'th':
        mel_basis = mel_basis[:, np.newaxis, :, np.newaxis]
        # print('th', mel_basis.shape)
    else:
        mel_basis = np.transpose(mel_basis, (1, 0))
        mel_basis = mel_basis[:, np.newaxis, np.newaxis, :]
        # print('tf', mel_basis.shape)
    stft2mel = Convolution2D(n_mels, n_freq, 1, border_mode='valid', bias=False,
                             name='stft2mel', weights=[mel_basis])
    stft2mel.trainable = trainable
    Melgram.add(stft2mel)  # output: (None, 128, 1, 375) if theano.
    if logamplitude:
        Melgram.add(Logam_layer())
    # i.e. 128ch == 128 mel-bin, for 375 time-step, therefore,
    if K.image_dim_ordering() == 'th':
        Melgram.add(Permute((2, 1, 3), name='ch_freq_time'))
    else:
        Melgram.add(Permute((1, 3, 2), name='ch_freq_time'))
    # output dot product of them
    return Melgram
The downside is some overhead from stacking many layers.
4. Customising Layer
When Lego-ing known layers doesn’t get you what you want, write your own!
4.1 Read the document
https://keras.io/layers/writing-your-own-keras-layers/ Read this! Whether you fully understand it or not. I didn’t fully understand but later I got it thanks to @fchollet’s help.
4.2 Four methods
4.2.1 __init__()
Initialise the layer. Assign attributes to self so that you can use them later.
4.2.2 build(self, input_shape)
- Initialise the tensor variables (e.g. W, bias, or whatever) using Keras backend functions (e.g., self.W = K.variable(an_init_numpy_array)).
- Set self.trainable_weights to a list of variables, e.g., self.trainable_weights = [self.W].
Remember: trainable weights should be tensor variables so that the machine can auto-differentiate them for you.
Remember (2): Check the dtype of every variable! If you initialise a tensor variable with a float64 numpy array, the variable might also be float64, which will get you an error. Usually it won't, because by default K.variable() casts the value to float32. But check, check, check! Simply print x.dtype.
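The dtype point is easy to check on the NumPy side before anything reaches the backend. The following is a small sketch, assuming only NumPy; the array w_init is a made-up initial value, not something from a real layer.

```python
import numpy as np

# NumPy creates float64 arrays by default.
w_init = np.random.rand(3, 4)
print(w_init.dtype)

# Cast explicitly before handing the array to the backend.
w_init32 = w_init.astype('float32')
print(w_init32.dtype)

# Scalar constants are float64 as well, so cast them too.
const = np.sqrt(2. * np.pi).astype('float32')
print(const.dtype)
```

Casting explicitly costs one call and removes the guesswork about what the backend will do with the value.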
4.2.3 call(self, x, mask=None)
This is where you implement the forward-pass operation. You may want to take the dot product of the input and one of the trainable weights (K.dot(x, self.W)), expand the dimensionality of a tensor variable (K.expand_dims(var1, dim=2)), or whatever.
Again, dtype! For example, I had to use this line, np.sqrt(2. * np.pi).astype('float32'), to make the constant float32.
4.2.4 get_output_shape_for(self, input_shape)
As the name says.
4.3 Examples
4.3.1 Example 1 : Cropping2D Layer
It crops 2D input. Simple!
from keras import backend as K
from keras.engine import Layer, InputSpec

class Cropping2D(Layer):
    def __init__(self, cropping=((0, 0), (0, 0)), dim_ordering='default', **kwargs):
        super(Cropping2D, self).__init__(**kwargs)
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        self.cropping = tuple(cropping)
        assert len(self.cropping) == 2, 'cropping must be a tuple length of 2'
        assert len(self.cropping[0]) == 2, 'cropping[0] must be a tuple length of 2'
        assert len(self.cropping[1]) == 2, 'cropping[1] must be a tuple length of 2'
        assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
        self.dim_ordering = dim_ordering
        self.input_spec = [InputSpec(ndim=4)]

    def build(self, input_shape):
        # nothing to train; just remember the full input shape for call()
        self.input_spec = [InputSpec(shape=input_shape)]

    def get_output_shape_for(self, input_shape):
        if self.dim_ordering == 'th':
            return (input_shape[0],
                    input_shape[1],
                    input_shape[2] - self.cropping[0][0] - self.cropping[0][1],
                    input_shape[3] - self.cropping[1][0] - self.cropping[1][1])
        elif self.dim_ordering == 'tf':
            return (input_shape[0],
                    input_shape[1] - self.cropping[0][0] - self.cropping[0][1],
                    input_shape[2] - self.cropping[1][0] - self.cropping[1][1],
                    input_shape[3])
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)

    def call(self, x, mask=None):
        input_shape = self.input_spec[0].shape
        if self.dim_ordering == 'th':
            return x[:,
                     :,
                     self.cropping[0][0]:input_shape[2]-self.cropping[0][1],
                     self.cropping[1][0]:input_shape[3]-self.cropping[1][1]]
        elif self.dim_ordering == 'tf':
            return x[:,
                     self.cropping[0][0]:input_shape[1]-self.cropping[0][1],
                     self.cropping[1][0]:input_shape[2]-self.cropping[1][1],
                     :]

    def get_config(self):
        config = {'cropping': self.cropping}
        base_config = super(Cropping2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
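The shape arithmetic in get_output_shape_for and the slicing in call above have to agree with each other. Here is a pure-Python sketch of that arithmetic, with no Keras involved. The function name cropped_shape and the names top, bottom, left and right are my own labels for illustration.

```python
def cropped_shape(input_shape, cropping, dim_ordering='th'):
    """Same arithmetic as get_output_shape_for, written as a plain function."""
    (top, bottom), (left, right) = cropping
    if dim_ordering == 'th':   # (batch, channel, row, col)
        return (input_shape[0], input_shape[1],
                input_shape[2] - top - bottom,
                input_shape[3] - left - right)
    if dim_ordering == 'tf':   # (batch, row, col, channel)
        return (input_shape[0],
                input_shape[1] - top - bottom,
                input_shape[2] - left - right,
                input_shape[3])
    raise ValueError('Invalid dim_ordering: ' + dim_ordering)

cropping = ((2, 3), (4, 4))
th_shape = cropped_shape((None, 3, 28, 28), cropping, 'th')
tf_shape = cropped_shape((None, 28, 28, 3), cropping, 'tf')
print(th_shape)  # (None, 3, 23, 20)
print(tf_shape)  # (None, 23, 20, 3)

# The slice used in call() keeps exactly that many rows.
rows = list(range(28))
kept_rows = rows[2:len(rows) - 3]
print(len(kept_rows))  # 23
```

If the two ever disagree, the layers stacked after this one will be built for the wrong shape, so it is worth checking with small numbers like this.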
4.3.2 Example 2 : ParametricMel
import numpy as np
from keras import backend as K
from keras.engine import Layer
# _mel_frequencies is a helper defined elsewhere; it returns n_mels frequencies on the mel scale.

class ParametricMel(Layer):
    def __init__(self, n_mels, n_freqs, sr, scale=24., init='mel', **kwargs):
        self.supports_masking = True
        self.scale = scale  # scaling
        self.n_mels = n_mels
        if init == 'mel':
            self.means_init = np.array(_mel_frequencies(n_mels, fmin=0.0, fmax=sr/2), dtype='float32')
            stds = self.means_init[1:] - self.means_init[:-1]
            self.stds_init = 0.3 * np.hstack((stds[0:1], stds[:]))  # 0.3: kinda make sense by the resulting images..
        self.center_freqs_init = [float(i)*sr/2/(n_freqs-1) for i in range(n_freqs)]  # dft frequencies
        super(ParametricMel, self).__init__(**kwargs)

    def build(self, input_shape):
        self.means = K.variable(self.means_init,
                                name='{}_means'.format(self.name))
        self.stds = K.variable(self.stds_init,
                               name='{}_stds'.format(self.name))
        self.center_freqs_init = np.array(self.center_freqs_init)[np.newaxis, :]  # (1, n_freq)
        self.center_freqs_init = np.tile(self.center_freqs_init, (self.n_mels, 1))  # (n_mels, n_freq)
        self.center_freqs = K.variable(self.center_freqs_init,
                                       name='{}_center_freqs'.format(self.name))
        self.trainable_weights = [self.means, self.stds]  # center_freqs stays fixed
        self.n_freq = input_shape[1]
        self.n_time = input_shape[2]
        print('--build--')

    def get_output_shape_for(self, input_shape):
        return (input_shape[0], self.n_mels, input_shape[2])

    def call(self, x, mask=None):
        means = K.expand_dims(self.means, dim=1)
        stds = K.expand_dims(self.stds, dim=1)
        # one Gaussian per mel bin, evaluated at every DFT frequency
        freq_to_mel = (self.scale * K.exp(-1. * K.square(self.center_freqs - means) \
                                          / (2. * K.square(stds)))) \
                      / (np.sqrt(2. * np.pi).astype('float32') * stds)  # (n_mel, n_freq)
        out = K.dot(freq_to_mel, x)  # (n_mel, None, n_time)
        return K.permute_dimensions(out, (1, 0, 2))
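To see what call() above computes, here is a NumPy-only sketch of the same forward pass: each mel bin is a Gaussian over the DFT frequencies, and the resulting matrix is multiplied with the spectrogram. The function names gaussian_filterbank and forward, and all the numbers (means, stds, sizes), are made up for illustration; the real layer initialises the means from the mel scale and learns means and stds.

```python
import numpy as np

def gaussian_filterbank(means, stds, center_freqs, scale=24.0):
    """Return a (n_mels, n_freq) matrix; row i is a Gaussian centred at means[i]."""
    means = means[:, np.newaxis]          # (n_mels, 1)
    stds = stds[:, np.newaxis]            # (n_mels, 1)
    freqs = center_freqs[np.newaxis, :]   # (1, n_freq), broadcasts over mel bins
    gauss = np.exp(-np.square(freqs - means) / (2.0 * np.square(stds)))
    fb = scale * gauss / (np.sqrt(2.0 * np.pi) * stds)
    return fb.astype('float32')           # keep everything float32

def forward(freq_to_mel, x):
    """x: (batch, n_freq, n_time) -> (batch, n_mels, n_time)."""
    out = np.dot(freq_to_mel, x)          # (n_mels, batch, n_time)
    return np.transpose(out, (1, 0, 2))

sr, n_freq, n_mels = 22050, 9, 4
center_freqs = np.array([i * sr / 2.0 / (n_freq - 1) for i in range(n_freq)],
                        dtype='float32')
means = np.linspace(1000.0, 9000.0, n_mels).astype('float32')  # made-up centres
stds = np.full(n_mels, 800.0, dtype='float32')                 # made-up widths

fb = gaussian_filterbank(means, stds, center_freqs)
x = np.random.rand(2, n_freq, 5).astype('float32')             # fake spectrogram batch
y = forward(fb, x)
print(fb.shape, y.shape)
```

The dot product puts the mel axis first, which is why the layer finishes with a permute to bring the batch axis back to the front.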
4.4 Tips
Remember: you need to make the layer's operation differentiable w.r.t. the input and the trainable weights you set. Look up the Keras backend functions and use them.
tensor.shape.eval() returns an integer tuple. You will need to print shapes a lot 😉