LinearOperator
LinearOperator is a PyTorch package for abstracting away the linear algebra routines needed for structured matrices (or operators).
This package is in beta.
Currently, most of the functionality only supports positive semi-definite and triangular matrices.
Package development TODOs:
To get started, run either
pip install linear_operator
conda install linear_operator -c gpytorch
or see below for more detailed instructions.
Why LinearOperator
Before describing what linear operators are and why they make a useful abstraction, it's easiest to see an example.
Let's say you wanted to compute a matrix solve:
$$\boldsymbol A^{-1} \boldsymbol b.$$
If you didn't know anything about the matrix $\boldsymbol A$, the simplest (and best) way to accomplish this in code is:
torch.linalg.solve(A, b)
While this is easy, the solve routine is $\mathcal O(N^3)$, which gets very slow as $N$ grows large.
However, let's imagine that we knew that $\boldsymbol A$ was equal to a low rank matrix plus a diagonal
(i.e. $\boldsymbol A = \boldsymbol C \boldsymbol C^\top + \boldsymbol D$
for some skinny matrix $\boldsymbol C$ and some diagonal matrix $\boldsymbol D$).
There's now a very efficient $\mathcal O(N)$ routine (for a fixed rank of $\boldsymbol C$) to compute $\boldsymbol A^{-1} \boldsymbol b$: the Woodbury formula.
In general, if we know that $\boldsymbol A$ has structure,
we want to use efficient linear algebra routines - rather than the general routines -
that exploit this structure.
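For reference, this is the Woodbury identity specialized to $\boldsymbol A = \boldsymbol C \boldsymbol C^\top + \boldsymbol D$, assuming $\boldsymbol C \in \mathbb R^{N \times K}$ with $K \ll N$ and an invertible diagonal $\boldsymbol D$. It trades the $N \times N$ solve for a $K \times K$ one, so the cost is $\mathcal O(NK^2 + K^3)$, which is linear in $N$ for a fixed rank $K$:

$$\left(\boldsymbol C \boldsymbol C^\top + \boldsymbol D\right)^{-1} \boldsymbol b = \boldsymbol D^{-1} \boldsymbol b - \boldsymbol D^{-1} \boldsymbol C \left(\boldsymbol I_K + \boldsymbol C^\top \boldsymbol D^{-1} \boldsymbol C\right)^{-1} \boldsymbol C^\top \boldsymbol D^{-1} \boldsymbol b.$$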
Without LinearOperator
Implementing the efficient solve that exploits $\boldsymbol A$'s low-rank-plus-diagonal structure would look something like this:
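import torch

def low_rank_plus_diagonal_solve(C, d, b):
    # A = C C^T + diag(d), with C of size N x K and d > 0
    # A^{-1} b = D^{-1} b - D^{-1} C (I + C^T D^{-1} C)^{-1} C^T D^{-1} b   (Woodbury)
    b_mat = b.unsqueeze(-1) if b.dim() == 1 else b  # cholesky_solve expects a matrix right-hand side
    D_inv_b = b_mat / d.unsqueeze(-1)
    D_inv_C = C / d.unsqueeze(-1)
    eye = torch.eye(C.size(-1), dtype=C.dtype)  # K x K identity (K = rank of C)
    L = torch.linalg.cholesky(eye + C.mT @ D_inv_C, upper=False)
    res = D_inv_b - D_inv_C @ torch.cholesky_solve(C.mT @ D_inv_b, L, upper=False)
    return res.squeeze(-1) if b.dim() == 1 else res

C = torch.randn(1000, 20)
d = torch.rand(1000) + 0.1  # positive diagonal
b = torch.randn(1000)
low_rank_plus_diagonal_solve(C, d, b)  # computes A^{-1} b in O(N) time instead of O(N^3)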
While this is efficient code, it's not ideal for a number of reasons:
- It's a lot more complicated than torch.linalg.solve(A, b).
- There's no object that represents $\boldsymbol A$. To perform any math with $\boldsymbol A$, we have to pass around the matrix C and the vector d.
With LinearOperator
The LinearOperator package offers the best of both worlds:
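import torch
from linear_operator.operators import DiagLinearOperator, LowRankRootLinearOperator
C, d, b = torch.randn(1000, 20), torch.rand(1000) + 0.1, torch.randn(1000)
A = LowRankRootLinearOperator(C) + DiagLinearOperator(d)  # represents C C^T + diag(d)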
It provides an interface that lets us treat $\boldsymbol A$ as if it were a generic tensor,
using the standard PyTorch API:
torch.linalg.solve(A, b)
Under the hood, the LinearOperator object keeps track of the algebraic structure of $\boldsymbol A$ (low rank plus diagonal)
and determines the most efficient routine to use (the Woodbury formula).
This way, we get an efficient $\mathcal O(N)$ solve while abstracting away all of the details.
Crucially, $\boldsymbol A$ is never explicitly instantiated as a matrix, which makes it possible to scale
to very large operators without running out of memory:
A = LowRankRootLinearOperator(C) + DiagLinearOperator(d)  # e.g. C of size 10,000,000 x 20; A is never formed densely
torch.linalg.solve(A, b)  # still an efficient solve
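A back-of-the-envelope calculation shows why this matters. The sketch below assumes (hypothetically) $N$ = 10,000,000, a rank-20 $\boldsymbol C$, and 32-bit floats, and compares the memory of a dense $\boldsymbol A$ with the memory of the tensors that define the operator:

```python
# Hypothetical sizes: N = 10,000,000, rank K = 20, float32 entries (4 bytes each)
N, K, BYTES_PER_ENTRY = 10_000_000, 20, 4

dense_bytes = N * N * BYTES_PER_ENTRY             # explicit N x N matrix A
structured_bytes = (N * K + N) * BYTES_PER_ENTRY  # C (N x K) plus the diagonal d (length N)

print(f"dense A: {dense_bytes / 1e12:,.0f} TB")       # 400 TB
print(f"C and d: {structured_bytes / 1e9:,.2f} GB")   # 0.84 GB
```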
What is a Linear Operator?
A linear operator is a generalization of a matrix.
It is a linear function that is defined by its application to a vector.
The most common linear operators are (potentially structured) matrices,
where the function applying them to a vector is a (potentially efficient)
matrix-vector multiplication routine.
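As a concrete, hypothetical illustration (plain PyTorch, not part of the LinearOperator API), the function below defines the $N \times N$ tridiagonal "second difference" matrix purely through its action on a vector. Applying it costs $\mathcal O(N)$ time and memory, and the matrix itself is never stored:

```python
import torch

def second_difference(v: torch.Tensor) -> torch.Tensor:
    """Apply the N x N tridiagonal matrix with 2 on the diagonal and -1 on the
    off-diagonals to v (along its last dimension), without building that matrix."""
    out = 2 * v
    out[..., 1:] -= v[..., :-1]   # subdiagonal contribution: -v[i - 1]
    out[..., :-1] -= v[..., 1:]   # superdiagonal contribution: -v[i + 1]
    return out

print(second_difference(torch.arange(5.)))  # values: [-1., 0., 0., 0., 5.]
```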
In code, a LinearOperator is a class that
- specifies the tensor(s) needed to define the LinearOperator,
- specifies a _matmul function (how the LinearOperator is applied to a vector),
- specifies a _size function (how big the LinearOperator is if it is represented as a matrix, or batch of matrices),
- specifies a _transpose_nonbatch function (the adjoint of the LinearOperator), and
- (optionally) defines other functions (e.g. logdet, eigh, etc.) to accelerate computations for which efficient structure-exploiting routines exist.
For example:
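import torch
import linear_operator

class DiagLinearOperator(linear_operator.LinearOperator):
    r"""
    A LinearOperator representing a diagonal matrix.
    """
    def __init__(self, diag):
        # diag: the vector (or batch of vectors) that defines the diagonal
        super().__init__(diag)
        self.diag = diag

    def _matmul(self, v):
        return self.diag.unsqueeze(-1) * v

    def _size(self):
        return torch.Size([*self.diag.shape, self.diag.size(-1)])

    def _transpose_nonbatch(self):
        return self  # diagonal matrices are symmetric

    # Optional: a structure-exploiting override (valid for a positive diagonal)
    def logdet(self):
        return self.diag.log().sum(dim=-1)

D = DiagLinearOperator(torch.tensor([1., 2., 3.]))
torch.matmul(D, torch.tensor([4., 5., 6.]))  # [4., 10., 18.]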
While _matmul, _size, and _transpose_nonbatch might seem like a limited set of functions,
it turns out that most functions in the torch and torch.linalg namespaces can be efficiently implemented
using only these three primitive functions.
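To see why so few primitives suffice, here is a deliberately naive sketch (hypothetical helper names; this is not how LinearOperator implements these functions, which it typically does far more efficiently). Given only a matmul rule and a size, one can always recover the dense matrix by multiplying against the identity, and then fall back on any dense routine:

```python
import torch

def to_dense(matmul, size):
    # Applying the operator to the identity applies it to every basis vector at once,
    # so the result is the operator written out as an explicit matrix.
    return matmul(torch.eye(size[-1]))

def generic_diagonal(matmul, size):
    return to_dense(matmul, size).diagonal(dim1=-2, dim2=-1)

def generic_logdet(matmul, size):
    # Generic fallback: materialize, then call a dense O(N^3) routine.
    return torch.logdet(to_dense(matmul, size))

# Usage with the same _matmul / _size rules as the DiagLinearOperator example
diag = torch.tensor([1., 2., 3.])

def diag_matmul(v):
    return diag.unsqueeze(-1) * v

diag_size = torch.Size([*diag.shape, diag.size(-1)])
print(to_dense(diag_matmul, diag_size))
print(generic_logdet(diag_matmul, diag_size))  # log(1 * 2 * 3) = log(6)
```

The efficient versions replace the "materialize, then call a dense routine" step with algorithms that only ever call the matmul rule a small number of times.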
Moreover, because _matmul is a linear function, it is very easy to compose linear operators in various ways.
For example, adding two linear operators (SumLinearOperator) just requires adding the outputs of their _matmul functions.
This makes it possible to define very complex compositional structures that still yield efficient linear algebraic routines.
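The sketch below illustrates that idea in plain PyTorch with hypothetical helper names (SumLinearOperator and related classes exist in the package, but their internals are not shown here). Each operator is represented only by its matmul rule, and sums, products, and row-wise concatenation are built by combining those rules:

```python
import torch

def sum_matmul(matmul_a, matmul_b):
    # (A + B) v = A v + B v
    return lambda v: matmul_a(v) + matmul_b(v)

def product_matmul(matmul_a, matmul_b):
    # (A B) v = A (B v)
    return lambda v: matmul_a(matmul_b(v))

def cat_rows_matmul(matmul_a, matmul_b):
    # [A; B] v = [A v; B v], i.e. concatenation along dim=-2
    return lambda v: torch.cat([matmul_a(v), matmul_b(v)], dim=-2)

# Usage: diag(d) + C C^T, applied in O(N K) time without forming any N x N matrix
N, K = 6, 2
d = torch.rand(N) + 0.5
C = torch.randn(N, K)

def diag_matmul(v):
    return d.unsqueeze(-1) * v

def c_matmul(v):
    return C @ v

def ct_matmul(v):
    return C.mT @ v

A_matmul = sum_matmul(diag_matmul, product_matmul(c_matmul, ct_matmul))
```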
Finally, LinearOperator objects can be composed with one another, yielding new LinearOperator objects and automatically keeping track of algebraic structure after each computation.
As a result, users never need to reason about which efficient linear algebra routines to use (so long as the input elements defined by the user encode known input structure).
See the Using LinearOperator Objects section for more details.
Use Cases
There are several use cases for the LinearOperator package.
Here we highlight two general themes:
Modular Code for Structured Matrices
For example, let's say that you have a generative model that involves
sampling from a high-dimensional multivariate Gaussian.
This sampling operation will require storing and manipulating a large covariance matrix,
so to speed things up you might want to experiment with different structured
approximations of that covariance matrix.
This is easy with the LinearOperator package.
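import torch
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import DiagLinearOperator
variance = torch.ones(1000)  # example variances
cov = DiagLinearOperator(variance)  # or any other structured LinearOperator
mvn = MultivariateNormal(torch.zeros(cov.size(-1)), cov)
mvn.rsample()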
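Part of why structure pays off here is that sampling only needs some matrix $\boldsymbol L$ with $\boldsymbol L \boldsymbol L^\top$ equal to the covariance. The sketch below (plain PyTorch, hypothetical function name; a standard reparameterization trick, not necessarily what GPyTorch does internally) draws samples with covariance $\boldsymbol C \boldsymbol C^\top + \operatorname{diag}(\boldsymbol d)$ without any Cholesky factorization, at $\mathcal O(NK)$ cost per sample:

```python
import torch

def sample_low_rank_plus_diag(mean, C, d, num_samples):
    """Draw samples from N(mean, C C^T + diag(d)) without forming the N x N covariance."""
    N, K = C.shape
    eps_low_rank = torch.randn(num_samples, K, dtype=C.dtype)
    eps_diag = torch.randn(num_samples, N, dtype=C.dtype)
    # Each row is mean + C @ eps1 + sqrt(d) * eps2 with independent standard normals;
    # the two terms contribute covariances C C^T and diag(d), which add.
    return mean + eps_low_rank @ C.mT + d.sqrt() * eps_diag

torch.manual_seed(0)
C = torch.tensor([[1.0], [2.0], [0.5]], dtype=torch.float64)
d = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64)
samples = sample_low_rank_plus_diag(torch.zeros(3, dtype=torch.float64), C, d, num_samples=200_000)
```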
Efficient Routines for Complex Operators
Many of the efficient linear algebra routines in LinearOperator are iterative algorithms
based on matrix-vector multiplication.
Since matrix-vector multiplication obeys many nice compositional properties,
it is possible to obtain efficient routines for extremely complex compositional LinearOperators:
import torch
from linear_operator.operators import KroneckerProductLinearOperator, RootLinearOperator, ToeplitzLinearOperator
# Example shapes: mat1 is a 200 x 200 PSD matrix, mat2 a 100 x 100 PSD matrix, vec3 a length-20000 vector
A = KroneckerProductLinearOperator(mat1, mat2) + RootLinearOperator(ToeplitzLinearOperator(vec3))
torch.linalg.solve(A, torch.randn(20000))
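The classic example of such an iterative algorithm is conjugate gradients. The bare-bones sketch below (plain PyTorch, hypothetical function name, no preconditioning; the package's own routines are more sophisticated) solves $\boldsymbol A \boldsymbol x = \boldsymbol b$ for a symmetric positive-definite $\boldsymbol A$ while only ever calling a matmul function, so it works unchanged for any composition of operators that can multiply a vector:

```python
import torch

def conjugate_gradients(matmul, b, max_iter=1000, tol=1e-6):
    """Solve A x = b for symmetric positive-definite A, touching A only through matmul(v)."""
    x = torch.zeros_like(b)
    r = b - matmul(x)
    p = r.clone()
    rs_old = r @ r
    for _ in range(max_iter):
        Ap = matmul(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol * b.norm():
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

# Usage: a low-rank-plus-diagonal operator, applied in O(N K) time per matmul
torch.manual_seed(0)
N, K = 200, 5
C = torch.randn(N, K, dtype=torch.float64)
d = torch.rand(N, dtype=torch.float64) + 1.0

def lr_plus_diag_matmul(v):
    return C @ (C.mT @ v) + d * v

b = torch.randn(N, dtype=torch.float64)
x = conjugate_gradients(lr_plus_diag_matmul, b, tol=1e-12)
```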
Using LinearOperator Objects
LinearOperator objects share (mostly) the same API as torch.Tensor objects.
Under the hood, these objects use __torch_function__ to intercept calls to functions in the torch and torch.linalg namespaces and dispatch them to efficient linear algebra routines.
This includes
torch.add
torch.cat
torch.clone
torch.diagonal
torch.dim
torch.div
torch.expand
torch.logdet
torch.matmul
torch.numel
torch.permute
torch.prod
torch.squeeze
torch.sub
torch.sum
torch.transpose
torch.unsqueeze
torch.linalg.cholesky
torch.linalg.eigh
torch.linalg.eigvalsh
torch.linalg.solve
torch.linalg.svd
Each of these functions will return either a torch.Tensor or a new LinearOperator object, depending on the function.
For example:
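C = torch.matmul(A, B)  # if A and B are LinearOperators, C is a new LinearOperator representing their product
torch.linalg.solve(C, d)  # returns a torch.Tensor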
For more examples, see the examples folder.
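The dispatch mechanism itself is standard PyTorch. The sketch below follows the documented "Tensor-like type" pattern for __torch_function__, with a hypothetical ToyDiag class standing in for a LinearOperator (it is not the package's code): calls such as torch.linalg.solve on a ToyDiag are routed to registered structure-exploiting handlers:

```python
import torch

HANDLED_FUNCTIONS = {}

def implements(torch_fn):
    """Register a structure-exploiting implementation for a torch function."""
    def decorator(fn):
        HANDLED_FUNCTIONS[torch_fn] = fn
        return fn
    return decorator

class ToyDiag:
    """Hypothetical stand-in for a LinearOperator: a diagonal matrix stored as a vector."""
    def __init__(self, diag):
        self.diag = diag

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

@implements(torch.matmul)
def _diag_matmul(op, rhs):
    return op.diag * rhs if rhs.dim() == 1 else op.diag.unsqueeze(-1) * rhs

@implements(torch.linalg.solve)
def _diag_solve(op, rhs):
    # O(N) instead of a dense O(N^3) solve
    return rhs / op.diag if rhs.dim() == 1 else rhs / op.diag.unsqueeze(-1)

@implements(torch.logdet)
def _diag_logdet(op):
    return op.diag.log().sum(-1)

D = ToyDiag(torch.tensor([1., 2., 4.]))
b = torch.tensor([2., 4., 8.])
torch.linalg.solve(D, b)  # dispatched to _diag_solve: [2., 2., 2.]
```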
Batch Support and Broadcasting
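LinearOperator objects operate naturally in batch mode.
For example, to represent a batch of three 100 x 100 diagonal matrices:
D = DiagLinearOperator(torch.randn(3, 100))  # represents a 3 x 100 x 100 operator
These objects fully support broadcasted operations:
D @ torch.randn(100, 2)  # result has shape 3 x 100 x 2
D2 = DiagLinearOperator(torch.randn([2, 1, 100]))  # represents a 2 x 1 x 100 x 100 operator
D2 + D  # broadcasts to a 2 x 3 x 100 x 100 operator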
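For diagonal operators, batch behaviour boils down to ordinary broadcasting of the diagonal vectors. The sketch below (plain PyTorch, not the package's internals) reproduces the shapes of the example above using only the defining tensors:

```python
import torch

# Hypothetical illustration: a batch of diagonal matrices is fully described by a
# tensor of diagonals of shape (*batch, N), so batch semantics are ordinary broadcasting.
d1 = torch.randn(3, 100)      # diagonals of D:  batch shape (3,)
d2 = torch.randn(2, 1, 100)   # diagonals of D2: batch shape (2, 1)

x = torch.randn(100, 2)
D_x = d1.unsqueeze(-1) * x    # D @ x: (3, 100, 1) * (100, 2) -> (3, 100, 2)
sum_diag = d2 + d1            # diagonal of D2 + D: (2, 1, 100) + (3, 100) -> (2, 3, 100)
```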
Indexing
LinearOperator
objects can be indexed in ways similar to torch Tensors. This includes:
- Integer indexing (get a row, column, or batch)
- Slice indexing (get a subset of rows, columns, or batches)
- LongTensor indexing (get a set of individual entries by index)
- Ellipsis indexing (support indexing operations with arbitrary batch dimensions)
D = DiagLinearOperator(torch.randn(2, 3, 100))  # represents a 2 x 3 x 100 x 100 operator
D[-1]  # 3 x 100 x 100
D[..., :10, -5:]  # 2 x 3 x 10 x 5
D[..., torch.LongTensor([0, 1, 2, 3]), torch.LongTensor([0, 1, 2, 3])]  # 2 x 3 x 4
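LongTensor indexing can stay cheap because it only needs individual entries. As a hypothetical sketch for the diagonal case (plain PyTorch, not the package's implementation), entry $(i, j)$ is $d_i$ when $i = j$ and zero otherwise, so the requested entries can be read straight off the diagonal vector without forming the matrix:

```python
import torch

def diag_entries(diag, rows, cols):
    """Return A[..., rows[i], cols[i]] for A = torch.diag_embed(diag), without forming A."""
    values = diag[..., rows]  # diag[..., rows[i]] for each requested (row, col) pair
    return torch.where(rows == cols, values, torch.zeros_like(values))

d = torch.randn(2, 3, 100)
rows = torch.tensor([0, 1, 2, 3, 5])
cols = torch.tensor([0, 1, 2, 3, 7])
entries = diag_entries(d, rows, cols)  # shape (2, 3, 5); the last pair is off-diagonal, so 0
```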
Composition and Decoration
LinearOperators can be composed with one another in various ways.
This includes
- Addition (LinearOpA + LinearOpB)
- Matrix multiplication (LinearOpA @ LinearOpB)
- Concatenation (torch.cat([LinearOpA, LinearOpB], dim=-2))
- Kronecker product (torch.kron(LinearOpA, LinearOpB))
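Kronecker products are a good example of why composition stays efficient. With row-major flattening, $(\boldsymbol M_1 \otimes \boldsymbol M_2)\,\mathrm{vec}(\boldsymbol X) = \mathrm{vec}(\boldsymbol M_1 \boldsymbol X \boldsymbol M_2^\top)$, so multiplying by a large Kronecker matrix costs two small matrix multiplies. A hypothetical plain-PyTorch sketch (not the package's implementation), checked against the dense torch.kron:

```python
import torch

def kron_matmul(M1, M2, x):
    """Compute torch.kron(M1, M2) @ x without forming the Kronecker product.

    M1 is (m, p), M2 is (n, q), and x has length p * q. With row-major flattening,
    (M1 kron M2) vec(X) = vec(M1 X M2^T), where X = x.reshape(p, q).
    """
    p, q = M1.size(-1), M2.size(-1)
    X = x.reshape(p, q)
    return (M1 @ X @ M2.mT).reshape(-1)

M1 = torch.randn(30, 30, dtype=torch.float64)
M2 = torch.randn(40, 40, dtype=torch.float64)
x = torch.randn(30 * 40, dtype=torch.float64)
y = kron_matmul(M1, M2, x)  # a 1200 x 1200 product, computed from a 30 x 30 and a 40 x 40 matrix
```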
In addition, there are many ways to "decorate" LinearOperator objects.
This includes:
- Elementwise multiplying by constants (torch.mul(2., LinearOpA))
- Summing over batches (torch.sum(LinearOpA, dim=-3))
- Elementwise multiplying over batches (torch.prod(LinearOpA, dim=-3))
See the documentation for a full list of supported composition and decoration operations.
Installation
LinearOperator requires Python >= 3.8.
Standard Installation (Most Recent Stable Version)
We recommend installing via pip
or Anaconda:
pip install linear_operator
conda install linear_operator -c gpytorch
The installation requires the following packages:
You can customize your PyTorch installation (e.g. CUDA version, CPU-only option)
by following the PyTorch installation instructions.
Installing from the main Branch (Latest Unstable Version)
To install what is currently on the main branch (potentially buggy and unstable):
pip install --upgrade git+https://github.com/cornellius-gp/linear_operator.git
Development Installation
If you are contributing a pull request, it is best to perform a manual installation:
git clone https://github.com/cornellius-gp/linear_operator.git
cd linear_operator
pip install -e ".[dev,docs,test]"
Contributing
See the contributing guidelines CONTRIBUTING.md
for information on submitting issues and pull requests.
License
LinearOperator is MIT licensed.