
A generic interface for linear algebra backends: code it once, run it on any backend
pip install backends
The basic use case for the package is to write code that automatically determines the backend to use depending on the types of its arguments.
Example:
import lab as B
import lab.autograd  # Load the AutoGrad extension.
import lab.torch  # Load the PyTorch extension.
import lab.tensorflow  # Load the TensorFlow extension.
import lab.jax  # Load the JAX extension.

def objective(matrix):
    # Compute matrix @ matrix^T and return the mean of its entries.
    outer_product = B.matmul(matrix, matrix, tr_b=True)
    return B.mean(outer_product)
The AutoGrad, PyTorch, TensorFlow, and JAX extensions are not loaded automatically, so as not to enforce a dependency on all four frameworks.
An extension can alternatively be loaded via import lab.autograd as B.
Run it with NumPy and AutoGrad:
>>> import autograd.numpy as np
>>> from autograd import grad
>>> objective(B.randn(np.float64, 2, 2))
0.15772589216756833
>>> grad(objective)(B.randn(np.float64, 2, 2))
array([[ 0.23519042, -1.06282928],
       [ 0.23519042, -1.06282928]])
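Why do both rows of the gradient above coincide? Let M be an n×m matrix and let c be the vector of its column sums. The sum of all entries of M Mᵀ equals Σ_k c_k², so the objective is Σ_k c_k² / n². Its derivative with respect to M_ik is therefore 2 c_k / n², which does not depend on the row i. The sketch below uses plain NumPy, not LAB. It checks this closed form against central finite differences. The helper names are hypothetical.

```python
import numpy as np

def objective_np(matrix):
    # Same computation as `objective`: mean of the entries of matrix @ matrix.T.
    return np.mean(matrix @ matrix.T)

def grad_closed_form(matrix):
    # Every row of the gradient equals 2 * (column sums) / n**2, n = number of rows.
    n = matrix.shape[0]
    col_sums = matrix.sum(axis=0)
    return np.tile(2 * col_sums / n**2, (n, 1))

def grad_finite_diff(f, matrix, h=1e-6):
    # Central finite differences, perturbing one entry at a time.
    g = np.zeros_like(matrix)
    for idx in np.ndindex(matrix.shape):
        step = np.zeros_like(matrix)
        step[idx] = h
        g[idx] = (f(matrix + step) - f(matrix - step)) / (2 * h)
    return g

rng = np.random.default_rng(0)
m = rng.standard_normal((2, 2))
print(np.allclose(grad_closed_form(m), grad_finite_diff(objective_np, m), atol=1e-6))
```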
Run it with TensorFlow:
>>> import tensorflow as tf
>>> objective(B.randn(tf.float64, 2, 2))
<tf.Tensor 'Mean:0' shape=() dtype=float64>
Run it with PyTorch:
>>> import torch
>>> objective(B.randn(torch.float64, 2, 2))
tensor(1.9557, dtype=torch.float64)
Run it with JAX:
>>> import jax
>>> import jax.numpy as jnp
>>> jax.jit(objective)(B.randn(jnp.float32, 2, 2))
DeviceArray(0.3109299, dtype=float32)
>>> jax.jit(jax.grad(objective))(B.randn(jnp.float32, 2, 2))
DeviceArray([[ 0.2525182, -1.26065 ],
             [ 0.2525182, -1.26065 ]], dtype=float32)
This section lists all available types, which can be used to check types of objects or extend functions.
Int # Integers
Float # Floating-point numbers
Complex # Complex numbers
Bool # Booleans
Number # Numbers
Numeric # Numerical objects, including booleans
DType # Data type
Framework # Anything accepted by supported frameworks
Device # Any device type
NPNumeric
NPDType
NPRandomState
NP # Anything NumPy
AGNumeric
AGDType
AGRandomState
AG # Anything AutoGrad
TFNumeric
TFDType
TFRandomState
TFDevice
TF # Anything TensorFlow
TorchNumeric
TorchDType
TorchDevice
TorchRandomState
Torch # Anything PyTorch
JAXNumeric
JAXDType
JAXDevice
JAXRandomState
JAX # Anything JAX
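The types above are meant for type checks and for extending functions based on the types of their arguments. LAB builds on the plum multiple-dispatch library. The sketch below does not use LAB or plum. It only illustrates the general idea of choosing an implementation from an argument's type, using the standard library's `functools.singledispatch`, which dispatches on the type of the first argument only. The function `describe` is hypothetical.

```python
from functools import singledispatch
import numbers

import numpy as np

@singledispatch
def describe(a):
    # Fallback used when no more specific implementation is registered.
    return "unknown"

@describe.register
def _(a: np.ndarray):
    # Chosen for NumPy arrays.
    return f"NumPy array with shape {a.shape}"

@describe.register
def _(a: numbers.Number):
    # Chosen for anything registered as a number, including int and float.
    return "Python number"
```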
This section lists all available constants and methods.
Positional arguments must be given positionally, and keyword arguments must be given as keyword arguments.
For example, sum(tensor, axis=1) is valid, but sum(tensor, 1) is not.
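In plain Python, this convention can be enforced with keyword-only parameters. The sketch below is hypothetical: `reduce_sum` is not part of LAB. It only mirrors the listed `sum(a, axis=None, squeeze=True)` signature on top of NumPy. It also assumes that `squeeze=False` keeps the reduced axis, as NumPy's `keepdims=True` does.

```python
import numpy as np

def reduce_sum(a, *, axis=None, squeeze=True):
    # Everything after the bare `*` is keyword-only, so `reduce_sum(x, 1)`
    # raises a TypeError instead of silently treating 1 as the axis.
    return np.sum(a, axis=axis, keepdims=not squeeze)
```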
The names of arguments are indicative of their function:
a, b, and c indicate general tensors.
dtype indicates a data type, e.g. np.float32 or tf.float64. For example, rand(np.float32) creates a NumPy random number, whereas rand(tf.float64) creates a TensorFlow random number. Data types are always given as the first argument.
shape indicates a shape. The dimensions of a shape are always given as separate arguments to the function: reshape(tensor, 2, 2) is valid, but reshape(tensor, (2, 2)) is not.
axis indicates an axis over which the function may perform its action. An axis is always given as a keyword argument.
device refers to a device on which a tensor can be placed, which can be either a framework-specific type or a string, e.g. "cpu".
ref indicates a reference tensor from which properties, like its shape and data type, will be used. For example, zeros(tensor) creates a tensor full of zeros of the same shape and data type as tensor.
See the documentation for more detailed descriptions of each function.
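The calling conventions for data types, shapes, and reference tensors can be illustrated with a NumPy-only sketch. Here the data type comes first, shape dimensions are separate arguments, and a single reference tensor may stand in for both. This `zeros` is a hypothetical stand-in, not LAB's implementation. The float64 default for `zeros(*shape)` is an assumption standing in for LAB's `default_dtype`.

```python
import numbers

import numpy as np

def zeros(*args):
    # Hypothetical NumPy-only sketch of the three documented call forms.
    if len(args) == 1 and isinstance(args[0], np.ndarray):
        # zeros(ref): copy the shape and data type of a reference tensor.
        ref = args[0]
        return np.zeros(ref.shape, dtype=ref.dtype)
    if args and not isinstance(args[0], numbers.Integral):
        # zeros(dtype, *shape): the data type is always the first argument.
        dtype, *shape = args
        return np.zeros(tuple(shape), dtype=dtype)
    # zeros(*shape): assumed default data type of float64.
    return np.zeros(args, dtype=np.float64)
```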
default_dtype # Default data type.
epsilon # Magnitude of diagonal to regularise matrices with.
cholesky_retry_factor # Retry a failed Cholesky, increasing `epsilon` by at most this factor.
nan
pi
log_2_pi
dtype(a)
dtype_float(dtype)
dtype_float(a)
dtype_int(dtype)
dtype_int(a)
promote_dtypes(*dtype)
issubdtype(dtype1, dtype2)
isabstract(a)
jit(f, **kw_args)
isnan(a)
real(a)
imag(a)
device(a)
on_device(device)
on_device(a)
set_global_device(device)
to_active_device(a)
zeros(dtype, *shape)
zeros(*shape)
zeros(ref)
ones(dtype, *shape)
ones(*shape)
ones(ref)
zero(dtype)
zero(*refs)
one(dtype)
one(*refs)
eye(dtype, *shape)
eye(*shape)
eye(ref)
linspace(dtype, a, b, num)
linspace(a, b, num)
range(dtype, start, stop, step)
range(dtype, stop)
range(dtype, start, stop)
range(start, stop, step)
range(start, stop)
range(stop)
cast(dtype, a)
identity(a)
round(a)
floor(a)
ceil(a)
negative(a)
abs(a)
sign(a)
sqrt(a)
exp(a)
log(a)
log1p(a)
sin(a)
arcsin(a)
cos(a)
arccos(a)
tan(a)
arctan(a)
tanh(a)
arctanh(a)
loggamma(a)
logbeta(a)
erf(a)
sigmoid(a)
softplus(a)
relu(a)
add(a, b)
subtract(a, b)
multiply(a, b)
divide(a, b)
power(a, b)
minimum(a, b)
maximum(a, b)
leaky_relu(a, alpha)
softmax(a, axis=None)
min(a, axis=None, squeeze=True)
max(a, axis=None, squeeze=True)
sum(a, axis=None, squeeze=True)
prod(a, axis=None, squeeze=True)
mean(a, axis=None, squeeze=True)
std(a, axis=None, squeeze=True)
logsumexp(a, axis=None, squeeze=True)
all(a, axis=None, squeeze=True)
any(a, axis=None, squeeze=True)
nansum(a, axis=None, squeeze=True)
nanprod(a, axis=None, squeeze=True)
nanmean(a, axis=None, squeeze=True)
nanstd(a, axis=None, squeeze=True)
argmin(a, axis=None)
argmax(a, axis=None)
lt(a, b)
le(a, b)
gt(a, b)
ge(a, b)
eq(a, b)
ne(a, b)
bvn_cdf(a, b, c)
cond(condition, f_true, f_false, *xs)
where(condition, a, b)
scan(f, xs, *init_state)
sort(a, axis=-1, descending=False)
argsort(a, axis=-1, descending=False)
quantile(a, q, axis=None)
to_numpy(a)
jit_to_numpy(a) # Caches results for `B.jit`.
transpose(a, perm=None) (alias: t, T)
matmul(a, b, tr_a=False, tr_b=False) (alias: mm, dot)
einsum(equation, *elements)
trace(a, axis1=0, axis2=1)
kron(a, b)
svd(a, compute_uv=True)
eig(a, compute_eigvecs=True)
solve(a, b)
inv(a)
pinv(a)
det(a)
logdet(a)
expm(a)
logm(a)
cholesky(a) (alias: chol)
cholesky_solve(a, b) (alias: cholsolve)
triangular_solve(a, b, lower_a=True) (alias: trisolve)
toeplitz_solve(a, b, c) (alias: toepsolve)
toeplitz_solve(a, c)
outer(a, b)
reg(a, diag=None, clip=True)
pw_dists2(a, b)
pw_dists2(a)
pw_dists(a, b)
pw_dists(a)
ew_dists2(a, b)
ew_dists2(a)
ew_dists(a, b)
ew_dists(a)
pw_sums2(a, b)
pw_sums2(a)
pw_sums(a, b)
pw_sums(a)
ew_sums2(a, b)
ew_sums2(a)
ew_sums(a, b)
ew_sums(a)
set_random_seed(seed)
create_random_state(dtype, seed=0)
global_random_state(dtype)
global_random_state(a)
set_global_random_state(state)
rand(state, dtype, *shape)
rand(dtype, *shape)
rand(*shape)
rand(state, ref)
rand(ref)
randn(state, dtype, *shape)
randn(dtype, *shape)
randn(*shape)
randn(state, ref)
randn(ref)
randcat(state, p, *shape)
randcat(p, *shape)
choice(state, a, *shape, p=None)
choice(a, *shape, p=None)
randint(state, dtype, *shape, lower=0, upper)
randint(dtype, *shape, lower=0, upper)
randint(*shape, lower=0, upper)
randint(state, ref, lower=0, upper)
randint(ref, lower=0, upper)
randperm(state, dtype, n)
randperm(dtype, n)
randperm(n)
randgamma(state, dtype, *shape, alpha, scale)
randgamma(dtype, *shape, alpha, scale)
randgamma(*shape, alpha, scale)
randgamma(state, ref, *, alpha, scale)
randgamma(ref, *, alpha, scale)
randbeta(state, dtype, *shape, alpha, beta)
randbeta(dtype, *shape, alpha, beta)
randbeta(*shape, alpha, beta)
randbeta(state, ref, *, alpha, beta)
randbeta(ref, *, alpha, beta)
shape(a, *dims)
rank(a)
length(a) (alias: size)
is_scalar(a)
expand_dims(a, axis=0, times=1)
squeeze(a, axis=None)
uprank(a, rank=2)
downrank(a, rank=2, preserve=False)
broadcast_to(a, *shape)
diag(a)
diag_extract(a)
diag_construct(a)
flatten(a)
vec_to_tril(a, offset=0)
tril_to_vec(a, offset=0)
stack(*elements, axis=0)
unstack(a, axis=0, squeeze=True)
reshape(a, *shape)
concat(*elements, axis=0)
concat2d(*rows)
tile(a, *repeats)
take(a, indices_or_mask, axis=0)
submatrix(a, indices_or_mask)
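The distance helpers are easier to read with a reference implementation in mind. Below is a NumPy sketch of what `pw_dists2(a, b)` and `ew_dists2(a, b)` are understood to compute, assuming rows index data points and columns index features. Pairwise distances form an n×m matrix. Element-wise distances pair row i of a with row i of b. The `_np` names are hypothetical, and returning the element-wise result as an n×1 column is an assumption; check LAB's documentation for the exact shapes it returns.

```python
import numpy as np

def pw_dists2_np(a, b):
    # Entry (i, j) is ||a_i - b_j||^2, expanded as ||a_i||^2 + ||b_j||^2 - 2 a_i . b_j.
    norms_a = np.sum(a**2, axis=1)[:, None]
    norms_b = np.sum(b**2, axis=1)[None, :]
    return norms_a + norms_b - 2 * a @ b.T

def ew_dists2_np(a, b):
    # Entry i is ||a_i - b_i||^2; a and b must have the same number of rows.
    return np.sum((a - b) ** 2, axis=1)[:, None]
```

The expanded form only needs one matrix product, which makes it cheap. However, rounding can leave tiny negative entries, so taking a square root of it to get `pw_dists` needs some care.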
You can get the device of a tensor with B.device(a), and you can execute a computation on a device by entering B.on_device(device) as a context:
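import lab as B
import lab.tensorflow  # Load the TensorFlow extension.
import tensorflow as tf
with B.on_device("gpu:0"):
    # Computation in this block runs on the first GPU.
    a = B.randn(tf.float32, 2, 2)
    b = B.randn(tf.float32, 2, 2)
    c = a @ b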
Within such a context, a tensor that is not on the active device can be moved to the active device with B.to_active_device(a).
You can also globally set the active device with B.set_global_device("gpu:0").
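A context manager makes it easy to picture how an active device can be tracked. The sketch below is a hypothetical pure-Python illustration of the semantics described above: a global default, overridden by the innermost `on_device` context. It is not LAB's implementation and moves no tensors. The names `active_device`, `_device_stack`, and `_global_device` are invented for the example.

```python
from contextlib import contextmanager

_global_device = "cpu"  # Hypothetical global default device.
_device_stack = []      # Hypothetical stack of devices entered via contexts.

def set_global_device(device):
    global _global_device
    _global_device = device

def active_device():
    # The innermost context wins; otherwise fall back to the global device.
    return _device_stack[-1] if _device_stack else _global_device

@contextmanager
def on_device(device):
    _device_stack.append(device)
    try:
        yield device
    finally:
        # Restore the previous device even if the block raises.
        _device_stack.pop()
```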
If a function is evaluated abstractly, then elements of the shape of a tensor, e.g. B.shape(a)[0], may also be tensors, which can break dispatch.
By entering B.lazy_shapes(), shapes and elements of shapes will be wrapped in a custom type to fix this issue.
import lab as B

with B.lazy_shapes():
    a = B.eye(2)
    print(type(B.shape(a)))
    # <class 'lab.shape.Shape'>
    print(type(B.shape(a)[0]))
    # <class 'lab.shape.Dimension'>
If you call a random number generator without providing a random state, e.g. B.randn(np.float32, 2), the global random state from the corresponding backend is used.
For JAX, since there is no global random state, LAB provides a JAX global random state accessible through B.jax_global_random_state once lab.jax is loaded.
If you do not want to use a global random state but rather explicitly maintain one, you can create a random state with B.create_random_state and then pass this as the first argument to the random number generators.
The random number generators will then return a tuple containing the updated random state and the random result.
import lab as B
import lab.tensorflow  # Load the TensorFlow extension.
import tensorflow as tf

# Create random state.
state = B.create_random_state(tf.float32, seed=0)
# Generate two random arrays.
state, x = B.randn(state, tf.float32, 2)
state, y = B.randn(state, tf.float32, 2)
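The (state, result) convention matters most for functional backends such as JAX, where a random state is an immutable key that must be threaded through explicitly. The NumPy sketch below mimics the convention. `create_state` and `randn_with_state` are hypothetical names. A NumPy `Generator` is mutated in place, so returning it only mirrors the documented interface.

```python
import numpy as np

def create_state(seed=0):
    # Hypothetical stand-in for B.create_random_state: a seeded NumPy Generator.
    return np.random.default_rng(seed)

def randn_with_state(state, *shape):
    # Return (state, sample), mirroring the (state, result) convention.
    # The Generator advances in place, so the same object is handed back.
    return state, state.standard_normal(shape)

state = create_state(seed=0)
state, x = randn_with_state(state, 2)
state, y = randn_with_state(state, 2)
```

Creating a second state with the same seed reproduces `x`, while `y` differs from `x` because the state advanced in between.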