TensorFlow.js Layers: High-Level Machine Learning Model API
Part of the TensorFlow.js ecosystem, TensorFlow.js Layers is a high-level API built on TensorFlow.js Core, enabling users to build, train and execute deep learning models in the browser. TensorFlow.js Layers is modeled after Keras and tf.keras and can load models saved from those libraries.
Importing
There are three ways to import TensorFlow.js Layers:
- You can access TensorFlow.js Layers through the union package of TensorFlow.js Core and Layers: @tensorflow/tfjs.
- You can get TensorFlow.js Layers as a standalone module: @tensorflow/tfjs-layers. Note that tfjs-layers has a peer dependency on tfjs-core, so if you import @tensorflow/tfjs-layers, you also need to import @tensorflow/tfjs-core.
- You can load it as a standalone script through unpkg.
Option 1 is the most convenient, but leads to a larger bundle size (we will be
adding more packages to it in the future). Use option 2 if you care about bundle
size.
Getting started
Building, training and executing a model
The following example shows how to build a toy model with only one dense layer to perform linear regression.
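import * as tf from '@tensorflow/tfjs';
// A sequential model with a single dense layer: one input feature, one output unit.
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
// Use mean squared error as the loss and stochastic gradient descent as the optimizer.
model.compile({loss: 'meanSquaredError', optimizer: 'SGD'});
// Four training samples that lie on the line y = 2x - 1.
const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
const ys = tf.tensor2d([[1], [3], [5], [7]], [4, 1]);
// Train for 500 epochs. Top-level await works in ES modules; elsewhere,
// run this code inside an async function.
await model.fit(xs, ys, {epochs: 500});
// Predict the output for the unseen input x = 5 and print it.
const output = model.predict(tf.tensor2d([[5]], [1, 1]));
output.print();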
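A dense layer with one unit and one input learns a single weight w and bias b. Assuming the loss is the plain mean of squared errors over the four training samples, the worked objective below shows why training should drive the layer toward w = 2 and b = -1, and therefore why the printed prediction should land near 9:

```latex
\begin{aligned}
\hat{y} &= w x + b \\
\mathcal{L}(w, b) &= \frac{1}{4} \sum_{i=1}^{4} \left( w x_i + b - y_i \right)^2,
  \qquad (x_i, y_i) \in \{(1, 1), (2, 3), (3, 5), (4, 7)\} \\
\mathcal{L}(2, -1) &= 0 \;\Rightarrow\; \hat{y}(5) = 2 \cdot 5 - 1 = 9
\end{aligned}
```

Because the loss is nonnegative and reaches zero at (w, b) = (2, -1), that point is the global minimum. SGD starts from randomly initialized weights and runs for a fixed 500 epochs, so the printed value only approximates 9 and varies from run to run.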
Loading a pretrained Keras model
You can also load a model previously trained and saved from elsewhere (e.g.,
from Python Keras) and use it for inference or transfer learning in the browser.
For example, in Python, save your Keras model using tensorflowjs, which can be installed with pip install tensorflowjs.
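import tensorflowjs as tfjs
# `model` is an existing, trained Keras model; this writes model.json plus
# binary weight files into the target directory.
tfjs.converters.save_keras_model(model, '/path/to/tfjs_artifacts/')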
To load the model with TensorFlow.js Layers:
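import * as tf from '@tensorflow/tfjs';
// Fetches model.json and then the weight files it references.
// (From TensorFlow.js 1.0 on, this function is named tf.loadLayersModel.)
const model = await tf.loadModel('http://foo.bar/tfjs_artifacts/model.json');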
For more information