BranchKey Python Client


Official Python client for the BranchKey federated learning platform. This library provides a simple interface to upload model weights, download aggregated results, and track training runs.
Installation
pip install branchkey
Requirements: Python 3.9 or higher
Quick Start
1. Get Credentials
Create a leaf entity on the BranchKey platform (via the /v2/entities API endpoint) to obtain its credentials:
credentials = {
    "id": "your-leaf-uuid",
    "name": "my-client",
    "session_token": "your-session-token-uuid",
    "owner_id": "your-user-uuid",
    "tree_id": "your-tree-uuid",
    "branch_id": "your-branch-uuid",
}
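Rather than pasting the session token into source code, the credentials can be kept in a JSON file. The load_credentials helper below is hypothetical (not part of the branchkey package), and the file name credentials.json is only an example; the sketch reads such a file and checks that the keys shown above are present and non-empty.

```python
import json
from pathlib import Path

# Keys taken from the credentials example above; the client may accept others.
REQUIRED_KEYS = ("id", "name", "session_token", "owner_id", "tree_id", "branch_id")


def load_credentials(path):
    """Read a credentials dict from a JSON file and check the expected keys are set."""
    credentials = json.loads(Path(path).read_text(encoding="utf-8"))
    missing = [key for key in REQUIRED_KEYS if not credentials.get(key)]
    if missing:
        raise ValueError(f"credentials file is missing: {', '.join(missing)}")
    return credentials


# Usage (hypothetical file name):
#   credentials = load_credentials("credentials.json")
```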
2. Initialize Client
from branchkey.client import Client
client = Client(credentials, host="https://app.branchkey.com")
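3. Upload Model Weights
import numpy as np
# Relative influence of this update during aggregation (see Weighting Options)
weighting = 1000
# One NumPy array per parameter tensor, always in the same order
parameters = [layer1_weights, layer2_weights, ...]
# Write the weights to a local NPZ file, then upload it
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
print(f"Uploaded: {file_id}")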
4. Download Aggregated Results
# client.queue holds the IDs of aggregations that are ready to download
if not client.queue.empty():
    aggregation_id = client.queue.get(block=False)
    client.file_download(aggregation_id)
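The check above returns immediately when no aggregation is ready. To wait for the next one instead, a blocking get with a timeout is the idiomatic pattern. This sketch assumes client.queue is a standard queue.Queue (as its get(block=False) call suggests); wait_for_aggregation is a hypothetical helper, not part of the branchkey package.

```python
import queue


def wait_for_aggregation(q, timeout_s=60.0):
    """Block until an aggregation ID arrives on q; return None after timeout_s seconds."""
    try:
        return q.get(timeout=timeout_s)
    except queue.Empty:
        return None


# Usage with the client from the Quick Start (assumes client.queue is a queue.Queue):
#   aggregation_id = wait_for_aggregation(client.queue, timeout_s=300)
#   if aggregation_id is not None:
#       client.file_download(aggregation_id)
```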
Configuration Options
client = Client(
    credentials,
    host="https://app.branchkey.com",
    rbmq_host=None,
    rbmq_port=5672,
    ssl=True,
    wait_for_run=False,
    run_check_interval_s=30,
    proxies=None,
)
Model Weight Format
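Model weights are stored in compressed NumPy (.npz) format. Each update pairs a scalar weighting with an ordered list of NumPy arrays, one per parameter tensor, saved as layer_0, layer_1, and so on.
Structure
weighting = 1000  # scalar influence of this update during aggregation
parameters = [layer1, layer2, ...]  # NumPy arrays, always in the same order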
Weighting Options
The weighting parameter controls how much influence this update has during aggregation:
1. By Sample Count (Most Common)
weighting = len(train_dataset)
2. Equal Weighting
weighting = 1
3. Quality-Based Weighting
validation_accuracy = 0.85
weighting = len(train_dataset) * validation_accuracy
4. Manual Weighting
weighting = 5.0
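To see what the weighting does numerically, here is a sketch of FedAvg-style aggregation, in which each layer of the result is the weighted mean of the clients' layers. This assumes BranchKey aggregates with a plain weighted mean, which this document does not state; weighted_average is a hypothetical helper for illustration only. With weightings 1 and 3, the result lies three quarters of the way toward the second update.

```python
import numpy as np


def weighted_average(updates):
    """Combine (weighting, parameters) pairs into one parameter list by weighted mean.

    Every parameters entry is a list of NumPy arrays whose shapes match across updates.
    """
    total = sum(weighting for weighting, _ in updates)
    if total <= 0:
        raise ValueError("sum of weightings must be positive")
    n_layers = len(updates[0][1])
    # For each layer: sum(w_k * layer_k) / sum(w_k)
    return [
        sum(weighting * params[i] for weighting, params in updates) / total
        for i in range(n_layers)
    ]


client_a = (1, [np.array([0.0, 0.0]), np.array([[2.0]])])
client_b = (3, [np.array([4.0, 8.0]), np.array([[6.0]])])
print(weighted_average([client_a, client_b]))
```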
PyTorch Example
Option 1: convert the parameters manually. named_parameters() yields tensors in a fixed order; buffers such as BatchNorm running statistics are not included.
weighting = len(train_dataset)
parameters = [param.detach().cpu().numpy() for _, param in model.named_parameters()]
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
Option 2: use the client's helper, which returns the weighting together with the converted arrays:
weighting, parameters = client.convert_pytorch_numpy(
    model.named_parameters(),
    weighting=len(train_dataset),
)
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
TensorFlow/Keras Example
# For a tf.data.Dataset, len() only works when its cardinality is known
weighting = len(train_dataset)
# Each trainable variable becomes one NumPy array
parameters = [w.numpy() for w in model.trainable_weights]
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
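Manual NPZ Creation
import numpy as np
# The weighting is stored as a one-element float64 array under the key 'weighting'
arrays_dict = {'weighting': np.array([weighting], dtype=np.float64)}
# Each parameter array is stored as layer_0, layer_1, ... in upload order
for i, arr in enumerate(parameters):
    arrays_dict[f'layer_{i}'] = arr
np.savez_compressed("model_weights.npz", **arrays_dict)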
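Before uploading a hand-built file, it can help to confirm it matches the layout described in this section. The check_update_npz helper below is hypothetical (not part of the branchkey package) and assumes the upload format is exactly a one-element float64 'weighting' array plus contiguous keys layer_0 through layer_{n-1}, as in the manual example.

```python
import io

import numpy as np


def check_update_npz(file):
    """Return a list of problems with an upload NPZ; an empty list means it looks valid."""
    problems = []
    with np.load(file) as npz:
        keys = set(npz.files)
        if "weighting" not in keys:
            problems.append("missing 'weighting'")
        else:
            w = npz["weighting"]
            if w.shape != (1,) or w.dtype != np.float64:
                problems.append(f"'weighting' should be float64 of shape (1,), got {w.dtype} {w.shape}")
        layer_keys = keys - {"weighting"}
        # Assumed: layers are numbered 0..n-1 with no gaps
        expected = {f"layer_{i}" for i in range(len(layer_keys))}
        if layer_keys != expected:
            problems.append(f"unexpected layer keys: {sorted(layer_keys ^ expected)}")
    return problems


# Example: a minimal in-memory upload file passes the check
buffer = io.BytesIO()
np.savez_compressed(buffer, weighting=np.array([1530.0]), layer_0=np.zeros((32, 1, 5, 5), np.float32))
buffer.seek(0)
print(check_update_npz(buffer))  # -> []
```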
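Loading Aggregated Weights
import numpy as np
import torch
npz_data = np.load(f"aggregated_files/{aggregation_id}.npz")
# Sort numerically: plain sorted() would put layer_10 before layer_2
layer_keys = sorted((k for k in npz_data.files if k.startswith('layer_')), key=lambda k: int(k.split('_')[1]))
parameters = [npz_data[k] for k in layer_keys]
# Copy in place, in the same order the parameters were uploaded
with torch.no_grad():
    for param, arr in zip(model.parameters(), parameters):
        param.copy_(torch.from_numpy(arr))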
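String sorting orders layer keys character by character, so layer_10 comes before layer_2 once a model has more than ten parameter arrays. The load_layers helper below is hypothetical (not part of the branchkey package); it sorts by numeric index and accepts both the upload format (with 'weighting') and the aggregated format (layers only).

```python
import io

import numpy as np


def layer_index(key):
    """Numeric index of a 'layer_<i>' key, so 'layer_10' sorts after 'layer_9'."""
    return int(key.rsplit("_", 1)[1])


def load_layers(file):
    """Return (weighting or None, layer arrays in index order) from an upload or aggregated NPZ."""
    with np.load(file) as npz:
        layer_keys = sorted((k for k in npz.files if k.startswith("layer_")), key=layer_index)
        layers = [npz[k] for k in layer_keys]
        weighting = float(npz["weighting"][0]) if "weighting" in npz.files else None
    return weighting, layers


# Usage with a PyTorch model (path as in the loading example; order must match the upload):
#   _, layers = load_layers(f"aggregated_files/{aggregation_id}.npz")
#   with torch.no_grad():
#       for param, arr in zip(model.parameters(), layers):
#           param.copy_(torch.from_numpy(arr))
```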
Example NPZ File Contents
Client Upload Format (with weighting):
>>> npz_data.files
['weighting', 'layer_0', 'layer_1', 'layer_2', 'layer_3', ...]
>>> npz_data['weighting']
array([1530.])
>>> npz_data['layer_0'].shape, npz_data['layer_0'].dtype
((32, 1, 5, 5), dtype('float32'))
>>> npz_data['layer_0'][:1, :2, :2, :]
array([[[[-0.18576819, -0.03041792, 0.19532707, -0.11234483, -0.01512307],
         [ 0.19993757, -0.06492048, 0.08324468, -0.19899307, -0.0412709 ]]]],
      dtype=float32)
Aggregated Result Format (layers only):
>>> npz_data.files
['layer_0', 'layer_1', 'layer_2', ...]
Performance Metrics
Submit training or testing metrics for an aggregation. The data argument is a JSON-encoded string:
import json
metrics = {"accuracy": 0.95, "loss": 0.12}
client.send_performance_metrics(
    aggregation_id="aggregation-uuid",
    data=json.dumps(metrics),
    mode="test",
)
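Metric values computed with NumPy are often np.float32 or np.int64, which json.dumps rejects with a TypeError (PyTorch tensors should be converted with .item() first). The metrics_to_json helper below is a hypothetical sketch, not part of the branchkey package, that handles the NumPy case with a default hook.

```python
import json

import numpy as np


def _to_builtin(value):
    """json.dumps fallback: turn NumPy scalars and arrays into plain Python values."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def metrics_to_json(metrics):
    """Serialize a metrics dict that may contain NumPy values."""
    return json.dumps(metrics, default=_to_builtin)


# Usage:
#   client.send_performance_metrics(aggregation_id="aggregation-uuid",
#                                   data=metrics_to_json(metrics), mode="test")
```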
Client Properties
client.run_status
client.run_number
client.leaf_id
client.branch_id
client.is_authenticated
Proxy Support
For networks requiring proxy access:
proxies = {
    'http': 'http://user:password@proxy.example.com:8080',
    'https': 'http://user:password@proxy.example.com:8080',
}
client = Client(credentials, host="https://app.branchkey.com", proxies=proxies)
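To avoid hard-coding a proxy password, the same dictionary can be built from the standard http_proxy and https_proxy environment variables using the standard library's urllib.request.getproxies(). proxies_from_environment is a hypothetical helper, not part of the branchkey package; it assumes the Client accepts the dictionary shape shown above.

```python
import urllib.request


def proxies_from_environment():
    """Return http/https proxy URLs as reported by urllib.request.getproxies(), or None."""
    found = urllib.request.getproxies()
    proxies = {scheme: url for scheme, url in found.items() if scheme in ("http", "https")}
    return proxies or None


# Usage:
#   client = Client(credentials, host="https://app.branchkey.com",
#                   proxies=proxies_from_environment())
```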
Development
Running Tests
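git clone https://gitlab.com/branchkey/client_application.git
cd client_application
# Run the test suite via the Makefile target:
make local-test
# Or set up a virtual environment and run the tests manually:
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
python -m unittest -v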
Support
BranchKey - Federated Learning Platform