branchkey

Client application to interface with the BranchKey system

PyPI package · Version 2.7.0 · Maintainers: 1

BranchKey Python Client


License: GPL v3

Official Python client for the BranchKey federated learning platform. This library provides a simple interface to upload model weights, download aggregated results, and track training runs.

Installation

pip install branchkey

Requirements: Python 3.9 or higher

Quick Start

1. Get Credentials

Create a leaf entity through the BranchKey platform to obtain credentials via the /v2/entities API endpoint:

credentials = {
    "id": "your-leaf-uuid",
    "name": "my-client",
    "session_token": "your-session-token-uuid",
    "owner_id": "your-user-uuid",
    "tree_id": "your-tree-uuid",
    "branch_id": "your-branch-uuid"
}
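You may prefer to keep these credentials in a JSON file instead of hard-coding the session token in source. The following is a minimal sketch that assumes a local JSON file whose fields match the dict above. `load_credentials` and `REQUIRED_KEYS` are hypothetical names and are not part of the branchkey package.

```python
import json
from pathlib import Path

# Keys taken from the credentials example above (assumed to all be required)
REQUIRED_KEYS = ("id", "name", "session_token", "owner_id", "tree_id", "branch_id")


def load_credentials(path):
    """Hypothetical helper: read BranchKey leaf credentials from a JSON file."""
    with Path(path).open(encoding="utf-8") as f:
        credentials = json.load(f)
    # Fail early with a clear message instead of at Client(...) time
    missing = [k for k in REQUIRED_KEYS if k not in credentials]
    if missing:
        raise KeyError(f"credentials file is missing: {', '.join(missing)}")
    return credentials


# Usage:
# credentials = load_credentials("branchkey_credentials.json")
```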

2. Initialize Client

from branchkey.client import Client

# Connect to BranchKey
client = Client(credentials, host="https://app.branchkey.com")

3. Upload Model Weights

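import numpy as np

# Prepare model weights: one numpy array per layer
weighting = 1000  # Weight for aggregation (typically number of training samples)
parameters = [
    np.random.rand(32, 1, 5, 5).astype(np.float32),  # placeholder layer weights
    np.random.rand(32).astype(np.float32),           # placeholder layer biases
]

# Save to a compressed .npz file, then upload it
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
print(f"Uploaded: {file_id}")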

4. Download Aggregated Results

# Check for aggregation notifications
if not client.queue.empty():
    aggregation_id = client.queue.get(block=False)
    client.file_download(aggregation_id)
    # Downloaded to: ./aggregated_files/{aggregation_id}.npz
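The check above returns immediately if no aggregation has arrived yet. If you would rather block until one does, the following sketch assumes that `client.queue` behaves like a standard `queue.Queue`, which the `empty()` and `get(block=False)` calls suggest but this README does not state. `wait_for_aggregation` is a hypothetical helper.

```python
import queue


def wait_for_aggregation(q, timeout_s=600.0):
    """Block until an aggregation ID arrives on q; return None on timeout.

    Assumes q has queue.Queue semantics (get with block/timeout, raising queue.Empty).
    """
    try:
        return q.get(block=True, timeout=timeout_s)
    except queue.Empty:
        return None


# Usage with a connected client:
# aggregation_id = wait_for_aggregation(client.queue, timeout_s=300)
# if aggregation_id is not None:
#     client.file_download(aggregation_id)
```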

Configuration Options

client = Client(
    credentials,
    host="https://app.branchkey.com",  # API endpoint
    rbmq_host=None,                     # RabbitMQ host (auto-derived from host)
    rbmq_port=5672,                     # RabbitMQ port
    ssl=True,                           # Verify SSL certificates
    wait_for_run=False,                 # Wait if run is paused
    run_check_interval_s=30,            # Run status check interval
    proxies=None                        # HTTP/HTTPS proxy dict
)

Model Weight Format

Model weights are stored in compressed NPZ format.

Structure

# Format: (weighting, [list_of_parameter_arrays])
weighting = 1000  # Weight for aggregation (see below)
parameters = [layer1, layer2, ...]  # List of numpy arrays

Weighting Options

The weighting parameter controls how much influence this update has during aggregation:

1. By Sample Count (Most Common)

weighting = len(train_dataset)  # e.g., 1000 samples
# A client with 1000 samples has twice the influence of a client with 500 samples

2. Equal Weighting

weighting = 1  # All clients have equal influence

3. Quality-Based Weighting

validation_accuracy = 0.85
weighting = len(train_dataset) * validation_accuracy  # Weight by quality

4. Manual Weighting

weighting = 5.0  # Chosen by hand, e.g. a higher weight for a trusted client
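This README does not specify the server-side aggregation algorithm. To illustrate what "influence" means, the sketch below assumes a FedAvg-style weighted mean, where each layer is averaged across clients in proportion to `weighting`. `weighted_average` is a hypothetical function for illustration only, not the platform's implementation.

```python
import numpy as np


def weighted_average(updates):
    """Combine (weighting, parameters) updates into one list of layer arrays.

    Each layer is the sum of weighting * layer over clients, divided by the
    total weighting (a FedAvg-style weighted mean, assumed for illustration).
    """
    total = sum(w for w, _ in updates)
    num_layers = len(updates[0][1])
    return [
        sum(w * params[i] for w, params in updates) / total
        for i in range(num_layers)
    ]


client_a = (1000, [np.array([1.0, 1.0])])  # 1000 samples
client_b = (500, [np.array([4.0, 4.0])])   # 500 samples
print(weighted_average([client_a, client_b]))  # [array([2., 2.])]
```

With sample-count weighting, the result (2.0) lies twice as close to client A's value (1.0) as to client B's (4.0). With `weighting = 1` for both clients, it would be the plain mean, 2.5.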

PyTorch Example

import numpy as np

# Method 1: Using client helper (recommended)
weighting = len(train_dataset)
parameters = [param.detach().cpu().numpy() for _, param in model.named_parameters()]

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

# Method 2: Using convert_pytorch_numpy
weighting, parameters = client.convert_pytorch_numpy(
    model.named_parameters(),
    weighting=len(train_dataset)
)
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

TensorFlow/Keras Example

import numpy as np

weighting = len(train_dataset)
parameters = [w.numpy() for w in model.trainable_weights]  # one array per trainable variable

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)

Manual NPZ Creation

import numpy as np

# Save manually (without using client helper)
arrays_dict = {'weighting': np.array([weighting], dtype=np.float64)}
for i, arr in enumerate(parameters):
    arrays_dict[f'layer_{i}'] = arr

np.savez_compressed("model_weights.npz", **arrays_dict)  # numpy appends .npz if the name lacks it
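Before uploading a hand-built file, you can read it back to confirm that the key layout is correct. This minimal sketch uses placeholder zero-valued layers, with shapes borrowed from the example NPZ contents in this README, and writes them to a temporary directory.

```python
import os
import tempfile

import numpy as np

weighting = 1530
parameters = [
    np.zeros((32, 1, 5, 5), dtype=np.float32),  # placeholder conv weights
    np.zeros(32, dtype=np.float32),             # placeholder biases
]

# Same key layout as above: 'weighting' plus layer_0, layer_1, ...
arrays_dict = {"weighting": np.array([weighting], dtype=np.float64)}
for i, arr in enumerate(parameters):
    arrays_dict[f"layer_{i}"] = arr

path = os.path.join(tempfile.mkdtemp(), "model_weights.npz")
np.savez_compressed(path, **arrays_dict)

# Read it back and check that the layout survived the round trip
with np.load(path) as npz_data:
    print(npz_data.files)             # ['weighting', 'layer_0', 'layer_1']
    print(npz_data["weighting"])      # [1530.]
    print(npz_data["layer_0"].shape)  # (32, 1, 5, 5)
```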

Loading Aggregated Weights

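import numpy as np
import torch

# Load aggregated weights from the downloaded NPZ file
npz_data = np.load(f"aggregated_files/{aggregation_id}.npz")

# Note: Aggregated results only contain layers (no weighting).
# Sort numerically so that layer_10 comes after layer_9, not after layer_1.
layer_keys = sorted(
    (k for k in npz_data.files if k.startswith('layer_')),
    key=lambda k: int(k.split('_')[1]),
)
parameters = [npz_data[k] for k in layer_keys]

# Apply to your model (same order as model.parameters())
with torch.no_grad():
    for param, array in zip(model.parameters(), parameters):
        param.copy_(torch.from_numpy(array))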

Example NPZ File Contents

Client Upload Format (with weighting):

>>> npz_data.files
['weighting', 'layer_0', 'layer_1', 'layer_2', 'layer_3', ...]

>>> npz_data['weighting']
array([1530.])  # Weight for aggregation

>>> npz_data['layer_0'].shape, npz_data['layer_0'].dtype
((32, 1, 5, 5), dtype('float32'))

>>> npz_data['layer_0'][:1, :2, :2, :]
array([[[[-0.18576819, -0.03041792,  0.19532707, -0.11234483, -0.01512307],
         [ 0.19993757, -0.06492048,  0.08324468, -0.19899307, -0.0412709 ]]]],
       dtype=float32)

Aggregated Result Format (layers only):

>>> npz_data.files
['layer_0', 'layer_1', 'layer_2', ...]  # No weighting in aggregated results
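Because the two formats differ only in whether a `weighting` entry is present, one reader can handle both. The sketch below is a hypothetical helper, `read_npz`, built only on the key names shown above. It sorts layer keys by their numeric suffix, because a plain `sorted()` would place `layer_10` before `layer_2` in models with more than ten layers.

```python
import numpy as np


def read_npz(path):
    """Hypothetical helper: return (weighting or None, ordered list of layer arrays).

    Client uploads carry a 'weighting' entry; aggregated results do not.
    """
    with np.load(path) as npz_data:
        weighting = (
            float(npz_data["weighting"][0]) if "weighting" in npz_data.files else None
        )
        # Sort by the integer after 'layer_' so layer_10 follows layer_9
        layer_keys = sorted(
            (k for k in npz_data.files if k.startswith("layer_")),
            key=lambda k: int(k.split("_", 1)[1]),
        )
        layers = [npz_data[k] for k in layer_keys]
    return weighting, layers


# Usage:
# weighting, layers = read_npz(f"aggregated_files/{aggregation_id}.npz")
# weighting is None for aggregated results
```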

Performance Metrics

Submit training or testing metrics:

import json

metrics = {"accuracy": 0.95, "loss": 0.12}
client.send_performance_metrics(
    aggregation_id="aggregation-uuid",
    data=json.dumps(metrics),
    mode="test"  # "test", "train", or "non-federated"
)
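Because `data` is passed as a JSON string, metrics computed with NumPy can trip up `json.dumps`, which rejects types such as `np.float32` and `np.int64`. The following minimal sketch converts NumPy scalars to plain Python numbers first. `metrics_to_json` is a hypothetical helper, not part of the branchkey package.

```python
import json

import numpy as np


def metrics_to_json(metrics):
    """Serialize a metrics dict, converting NumPy scalars to built-in numbers."""
    def to_builtin(value):
        # np.generic covers NumPy scalar types such as np.float32 and np.int64
        return value.item() if isinstance(value, np.generic) else value

    return json.dumps({name: to_builtin(v) for name, v in metrics.items()})


payload = metrics_to_json({"accuracy": np.float32(0.95), "loss": 0.12, "epochs": np.int64(3)})

# Usage with a connected client:
# client.send_performance_metrics(aggregation_id="aggregation-uuid", data=payload, mode="test")
```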

Client Properties

client.run_status        # Current run status: "start", "stop", or "pause"
client.run_number        # Current run iteration
client.leaf_id           # Your leaf UUID
client.branch_id         # Parent branch UUID
client.is_authenticated  # Authentication status
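A training loop can use these properties to decide whether to train and upload in the current round. The sketch below maps the three documented `run_status` values to hypothetical loop actions. `next_action` and the action names are illustrative assumptions. A `SimpleNamespace` stands in for a real client so that the logic can be tried offline.

```python
from types import SimpleNamespace

# Documented run_status values mapped to what a training loop might do next
ACTIONS = {"start": "train", "pause": "wait", "stop": "exit"}


def next_action(client):
    """Hypothetical helper: pick the loop's next step from client properties."""
    if not client.is_authenticated:
        return "exit"
    return ACTIONS.get(client.run_status, "wait")


# Stand-in with the same attribute names as the real client
stub = SimpleNamespace(is_authenticated=True, run_status="pause")
print(next_action(stub))  # wait
```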

Proxy Support

For networks requiring proxy access:

proxies = {
    'http': 'http://user:password@proxy.example.com:8080',
    'https': 'http://user:password@proxy.example.com:8080',
}
client = Client(credentials, host="https://app.branchkey.com", proxies=proxies)
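To keep proxy passwords out of source code, you can build the `proxies` dict from the conventional `HTTP_PROXY` / `HTTPS_PROXY` environment variables. This minimal sketch assumes that `Client` accepts the same dict shape shown above. `proxies_from_env` is a hypothetical helper.

```python
import os


def proxies_from_env():
    """Build a proxies dict from HTTP_PROXY / HTTPS_PROXY (upper or lower case).

    Returns None when neither is set, matching the Client's proxies=None default.
    """
    proxies = {}
    for scheme in ("http", "https"):
        value = os.environ.get(f"{scheme.upper()}_PROXY") or os.environ.get(f"{scheme}_proxy")
        if value:
            proxies[scheme] = value
    return proxies or None


# Usage:
# client = Client(credentials, host="https://app.branchkey.com", proxies=proxies_from_env())
```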

Development

Running Tests

# Clone repository
git clone https://gitlab.com/branchkey/client_application.git
cd client_application

# Run tests with Docker (requires Docker)
make local-test

# Or manually
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
python -m unittest -v

Support

BranchKey - Federated Learning Platform
