We will use a very simple neural network as a toy model for the rest of the article. However, you can replace it with any model you like!
We are training a linear regression model: a single dense layer with one linear unit and no activation function. The training data comes from the line y = 2x + 5, which has a slope of 2 and an intercept of 5.
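With one unit and no activation, the dense layer computes y = w * x + b, which is exactly a linear regression. Because every training point lies on y = 2x + 5, the ordinary least-squares solution is the true line, so training should drive w toward 2 and b toward 5. The stdlib-only sketch below is not part of the original notebook. It rebuilds the same 100 evenly spaced points that `tf.range(-5, 5, 0.1)` produces (-5.0 up to 4.9) and solves for the least-squares slope and intercept in closed form. The function name `least_squares_line` is hypothetical.

```python
def least_squares_line(xs, ys):
    """Closed-form ordinary least squares for y = slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # slope = cov(x, y) / var(x); the 1/n factors cancel, so plain sums suffice
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    slope = cov_xy / var_x
    # The fitted line always passes through the point (mean_x, mean_y)
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Same grid as tf.range(-5, 5, 0.1): 100 points from -5.0 up to 4.9
xs = [-5 + 0.1 * i for i in range(100)]
ys = [2 * x + 5 for x in xs]

slope, intercept = least_squares_line(xs, ys)
print(f"Least-squares fit: {slope:.1f}x + {intercept:.1f}")
```

This is the target the trained network should reach, which makes it a useful reference for the sanity check in the notebook.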
- The code below is meant to be run in a Jupyter notebook. I suggest using Google Colaboratory; you can go directly to a preset Colaboratory notebook by clicking this link.
!pip install --quiet tensorflowjs
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import tensorflowjs as tfjs
# Set seed for reproducibility
tf.random.set_seed(1)
# Setup toy data
X = tf.range(-5, 5, 0.1)
Y = 2 * X + 5
# Create a simple model with 1 Linear unit in 1 Dense layer
model = Sequential([
Dense(units=1, input_shape=[1])
])
# Use stochastic gradient descent and mean squared error
model.compile(optimizer='sgd', loss='mse')
# Fit on the toy data
model.fit(X, Y, epochs=100, batch_size=8, verbose=0)
# Sanity check that the model has successfully trained
slope, intercept = [w.numpy().squeeze() for w in model.weights]
print("True Model: 2.0x + 5.0")
print(f"Fit Model: {slope:.1f}x + {intercept:.1f}")
# Save the model
MODEL_DIR = './js_model'
tfjs.converters.save_keras_model(model, MODEL_DIR)
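The `compile(optimizer='sgd', loss='mse')` and `fit(...)` calls above amount to mini-batch gradient descent on the mean squared error. The pure-Python sketch below is for intuition only and is not how Keras is implemented. It reuses the notebook's `epochs=100` and `batch_size=8`, assumes a learning rate of 0.01 (the default for Keras' SGD optimizer), and picks its own random starting weight, so its intermediate values will not match TensorFlow's. The function name `train_linear_sgd` is hypothetical.

```python
import random


def train_linear_sgd(xs, ys, epochs=100, batch_size=8, lr=0.01, seed=1):
    """Mini-batch SGD on mean squared error for y_hat = w * x + b.

    Illustrative sketch: epochs and batch_size mirror the notebook;
    lr=0.01 is assumed to match Keras' default SGD learning rate.
    """
    rng = random.Random(seed)
    w = rng.uniform(-1.0, 1.0)  # assumed random starting weight
    b = 0.0                     # bias starts at zero
    indices = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(indices)  # reshuffle each epoch, as fit() does by default
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            # Residuals (prediction - target) for this mini-batch
            errors = [(w * xs[i] + b) - ys[i] for i in batch]
            m = len(batch)
            # Gradients of mean((w*x + b - y)^2) with respect to w and b
            grad_w = 2.0 * sum(e * xs[i] for e, i in zip(errors, batch)) / m
            grad_b = 2.0 * sum(errors) / m
            # Plain gradient-descent update
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b


xs = [-5 + 0.1 * i for i in range(100)]
ys = [2 * x + 5 for x in xs]
w, b = train_linear_sgd(xs, ys)
print(f"SGD fit: {w:.1f}x + {b:.1f}")
```

Because every point lies exactly on the line, each mini-batch shares the same minimizer (w = 2, b = 5), so the updates settle on the true line rather than bouncing around it. That is why the notebook's sanity check can recover the slope and intercept so cleanly.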
Once you run the training cell, you should get output like the following.
[...]
True Model: 2.0x + 5.0
Fit Model: 2.0x + 5.0
[...]
If so, the toy model has trained successfully! However, the most important part of this example is the final few lines, where we save the model.
# Save the model
MODEL_DIR = './js_model'
tfjs.converters.save_keras_model(model, MODEL_DIR)
This saves the model in a format that TensorFlow.js can read, which we can then load in a Lambda function. All of the necessary files are written to the ./js_model directory, and you will find two types of files there:
- model.json – describes the model architecture (its topology) and includes a weights manifest listing the shard files, which is everything needed to rebuild the model.
- weight shards – binary files (named like group1-shard1of1.bin) that store the model's weight values; larger models are split across multiple shards.
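To see what the shards actually hold, the sketch below decodes them using only the standard library. It assumes the layout used for a small unquantized model like ours: `model.json` has a `weightsManifest` list whose groups name their shard files under `paths` and describe each weight (`name`, `shape`, `dtype`) under `weights`, and the shard bytes are those weights concatenated as little-endian float32 in manifest order. The function name `read_tfjs_weights` is hypothetical, and quantized or non-float32 weights are rejected rather than decoded.

```python
import json
import os
import struct


def read_tfjs_weights(model_dir):
    """Decode float32 weights from a TensorFlow.js model directory.

    Sketch under assumptions: weights are unquantized float32, and each
    manifest group's shards hold its weights concatenated in manifest order.
    """
    with open(os.path.join(model_dir, "model.json")) as f:
        model_json = json.load(f)

    weights = {}
    for group in model_json["weightsManifest"]:
        # Join this group's shard files back into a single byte buffer
        buffer = b""
        for path in group["paths"]:
            with open(os.path.join(model_dir, path), "rb") as shard:
                buffer += shard.read()
        offset = 0
        for spec in group["weights"]:
            if spec["dtype"] != "float32" or "quantization" in spec:
                raise ValueError(f"unsupported weight encoding: {spec['name']}")
            # Number of float32 values = product of the shape dimensions
            count = 1
            for dim in spec["shape"]:
                count *= dim
            values = struct.unpack_from(f"<{count}f", buffer, offset)
            offset += 4 * count  # 4 bytes per float32
            weights[spec["name"]] = (spec["shape"], list(values))
    return weights


# Example usage: only runs if the notebook's ./js_model directory exists
if os.path.isdir("./js_model"):
    for name, (shape, values) in read_tfjs_weights("./js_model").items():
        print(name, shape, values)
```

Run in the notebook after the save step, this should list the dense layer's kernel and bias, and their values should match the slope and intercept printed by the sanity check.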
Download these files. We will use them to build the AWS pipeline.