# TensorFlow.js Federated Learning Server

This library sets up a simple socket.io-based server for transmitting and receiving
TensorFlow.js model weights.

## Usage

### Basic
```js
import * as http from 'http';
import * as federated from 'federated-learning-server';

const INIT_MODEL = 'file:///initial/model.json';
const httpServer = http.createServer();
const fedServer = new federated.Server(httpServer, INIT_MODEL);

fedServer.setup().then(() => {
  httpServer.listen(8080);
});
```
### Setting the Initial Model
```js
// 1. Pass a tf.Model directly:
new federated.Server(httpServer, tfModel);

// 2. Pass a URL (https:// or file://) to be delegated to tf.loadModel:
new federated.Server(httpServer, 'https://remote.server/tf-model.json');
new federated.Server(httpServer, 'file:///my/local/file/tf-model.json');

// 3. Pass an async function that returns a tf.Model:
new federated.Server(httpServer, async () => {
  const model = await tf.loadModel('file:///transfer/learning/model.json');
  model.layers[0].trainable = false; // e.g. freeze a layer for transfer learning
  return model;
});

// 4. Pass your own FederatedServerModel implementation:
new federated.Server(httpServer, federatedServerModel);
```
The simplest way to set up a `federated.Server` is to pass a `tf.Model`. However, you can also pass a string that will be delegated to `tf.loadModel` (both `https?://` and `file://` URLs should work), or an asynchronous function that returns a `tf.Model`. The final option is to define your own `FederatedServerModel`, which must implement various saving and loading methods; see its documentation for more details.
Note that by default, different `tf.Model` versions will be saved as files in subfolders of `${process.cwd()}/saved-models/`. If you would like to change this directory, you can pass a `modelDir` configuration parameter, e.g. `federated.Server(httpServer, model, { modelDir: '/mnt/my-vfs' })`.

If you would like to skip the persistence layer entirely, you can instead import `FederatedServerInMemoryModel`, which updates a single model in memory. Furthermore, if you want a version of this library that omits socket.io in favor of a mocked-out version that works in the browser, check out the mock server library.
### Setting Hyperparameters
```js
new federated.Server(httpServer, model, {
  serverHyperparams: {
    aggregation: 'mean',
    minUpdatesPerVersion: 20,
  },
  clientHyperparams: {
    learningRate: 0.01,
    epochs: 5,
    examplesPerUpdate: 10,
    batchSize: 5,
    noiseStddev: 0.001
  },
  verbose: false,
  modelDir: '/mnt/my-vfs',
  modelCompileConfig: {
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  }
});
```
Many of these hyperparameters matter a great deal for the efficiency and privacy of learning, but the right settings depend on the nature of the data, the size of the model being trained, and how consistently the data is distributed across clients. In the future, we hope to support automated (and dynamic) tuning of these hyperparameters.
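For intuition about what `aggregation: 'mean'` and `minUpdatesPerVersion` govern, here is a minimal plain-JavaScript sketch of federated averaging. This is illustrative only: `meanAggregate`, `MIN_UPDATES_PER_VERSION`, and `pendingUpdates` are hypothetical names, not part of this library's API.

```js
// Illustrative sketch: once enough client updates accumulate, average them
// element-wise to produce the next model version. `meanAggregate` is a
// hypothetical helper, not part of the federated-learning-server API.
function meanAggregate(updates) {
  const n = updates.length;
  const dim = updates[0].length;
  const sums = new Array(dim).fill(0);
  for (const update of updates) {
    for (let i = 0; i < dim; i++) sums[i] += update[i];
  }
  return sums.map(s => s / n);
}

const MIN_UPDATES_PER_VERSION = 3; // stands in for minUpdatesPerVersion
const pendingUpdates = [[1, 2], [3, 4], [5, 6]]; // one flat weight array per client
if (pendingUpdates.length >= MIN_UPDATES_PER_VERSION) {
  console.log(meanAggregate(pendingUpdates)); // → [ 3, 4 ]
}
```

Raising `minUpdatesPerVersion` averages over more clients per version, which smooths out noise from any single client at the cost of slower version turnover.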
### Listening to Events
You can add an event listener that fires each time a client uploads a new set of weights (and optionally, self-reported metrics of how well the model performed on the examples used in training):
```js
fedServer.onUpload(message => {
  console.log(message.model.version);
  console.log(message.model.vars);
  console.log(message.clientId);
  console.log(message.metrics);
});
```
You can also listen for whenever the server computes a new version of the model:
```js
fedServer.onNewVersion((oldVersion, newVersion) => {
  console.log(`updated model from ${oldVersion} to ${newVersion}`);
});
```
## TODO
Robustness:
- `median` and `trimmed-mean` aggregations (for Byzantine-robustness)
- client authentication (e.g. Google OAuth, captchas)
- smoothing to limit individual clients' weight contributions (to prevent the model from overfitting to the most active clients, and to create preconditions for Byzantine-robust learning if some clients are adversarial)
- create virtual server-side clients who minimize train loss
- discard client updates that increase server-side train loss (or subtract updates' projections onto the direction of increasing train loss)
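To make the first robustness item concrete, here is a sketch of a coordinate-wise trimmed mean. This is an illustration of the idea, not part of the current API; `trimmedMean` and its `trim` parameter are hypothetical names.

```js
// Illustrative sketch of trimmed-mean aggregation: for each weight index,
// sort the client values, drop the `trim` smallest and `trim` largest,
// and average the rest. Not part of the federated-learning-server API.
function trimmedMean(updates, trim) {
  const n = updates.length;
  const dim = updates[0].length;
  if (n <= 2 * trim) throw new Error('need more than 2 * trim updates');
  const out = new Array(dim);
  for (let i = 0; i < dim; i++) {
    const column = updates.map(u => u[i]).sort((a, b) => a - b);
    const kept = column.slice(trim, n - trim);
    out[i] = kept.reduce((sum, v) => sum + v, 0) / kept.length;
  }
  return out;
}

// A single client reporting an extreme value (100) is simply dropped:
console.log(trimmedMean([[0], [1], [2], [100]], 1)); // → [ 1.5 ]
```

This bounds any single adversarial client's influence on a version; the median mentioned above is the limiting case where everything but the middle value(s) is trimmed.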
Privacy:
- determine how to set hyperparameters such that each version of the model is differentially private to individual clients' updates (i.e. prevent sensitive information from leaking from client->server->client)
- consider implementing secure aggregation (i.e. prevent sensitive information from leaking from client->server)
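One knob already exposed above is the client-side `noiseStddev` hyperparameter: adding zero-mean Gaussian noise to each uploaded update is a basic building block of differentially private learning. A minimal sketch of the idea follows; `gaussianSample` and `addNoise` are illustrative names, not this library's actual client-side implementation.

```js
// Illustrative sketch: perturb a flat weight update with zero-mean Gaussian
// noise before upload, as a `noiseStddev`-style knob would. Not the
// library's actual client-side code.
function gaussianSample(stddev) {
  // Box-Muller transform; Math.random() ∈ [0, 1), so 1 - u1 ∈ (0, 1].
  const u1 = Math.random();
  const u2 = Math.random();
  return stddev * Math.sqrt(-2 * Math.log(1 - u1)) * Math.cos(2 * Math.PI * u2);
}

function addNoise(update, stddev) {
  return update.map(w => w + gaussianSample(stddev));
}
```

The tradeoff is direct: a larger standard deviation hides more about any individual client's examples but slows convergence, and formal differential-privacy guarantees additionally require bounding each update's norm before noising.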