qpax

Differentiable QP solver in JAX.

  • 0.0.9
  • PyPI

Paper

This package can be used for solving convex quadratic programs of the following form:

$$ \begin{align*} \underset{x}{\text{minimize}} & \quad \frac{1}{2}x^TQx + q^Tx \\ \text{subject to} & \quad Ax = b, \\ & \quad Gx \leq h \end{align*} $$

where $Q \succeq 0$. This solver can be combined with JAX's jit and vmap functionality, as well as differentiated with reverse-mode grad.
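As a rough sketch of how a concrete problem maps onto this form, the snippet below encodes box constraints `lower <= x <= upper` as `Gx <= h` and checks that `Q` is positive semidefinite. The helper name `box_qp_to_standard_form` is hypothetical and not part of qpax. NumPy is used here only to keep the example self-contained, and the empty `(0, n)` equality block follows the convention used in the batched example later in this README.

```python
import numpy as np

def box_qp_to_standard_form(Q, q, lower, upper):
    """Hypothetical helper: express lower <= x <= upper as G x <= h (no equalities)."""
    n = q.shape[0]
    # the QP is convex only if Q is symmetric positive semidefinite
    assert np.allclose(Q, Q.T), "Q should be symmetric"
    assert np.linalg.eigvalsh(Q).min() >= -1e-9, "Q should be positive semidefinite"
    G = np.vstack([np.eye(n), -np.eye(n)])  # x <= upper  and  -x <= -lower
    h = np.concatenate([upper, -lower])
    A = np.zeros((0, n))                    # empty equality block
    b = np.zeros(0)
    return Q, q, A, b, G, h

Q = np.array([[2.0, 0.5], [0.5, 1.0]])
q = np.array([-1.0, 1.0])
Q, q, A, b, G, h = box_qp_to_standard_form(Q, q, lower=np.zeros(2), upper=np.ones(2))

# a point inside the box satisfies G x <= h, a point outside does not
x_inside = np.array([0.5, 0.5])
x_outside = np.array([1.5, 0.5])
feasible_inside = bool(np.all(G @ x_inside <= h))
feasible_outside = bool(np.all(G @ x_outside <= h))
```

The resulting arrays can be converted with `jnp.asarray` before being passed to the solver.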

The QP is solved with a primal-dual interior point algorithm detailed in cvxgen, with the linear systems solved using reduction techniques from cvxopt. At an approximate primal-dual solution, the primal variable $x$ is differentiated with respect to the problem parameters using the implicit function theorem, as shown in optnet and its PyTorch-based QP solver, qpth.
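To illustrate the implicit function theorem step, here is a minimal sketch for the simpler case of a QP with equality constraints only, where the KKT conditions are linear. This is not qpax's implementation (which also handles the inequality constraints through the interior point iterates); the function names `solve_eq_qp` and `dx_dq` are made up for this example. The Jacobian of $x$ with respect to $q$ is read off the KKT matrix and compared against finite differences.

```python
import numpy as np

def solve_eq_qp(Q, q, A, b):
    """Solve min 0.5 x'Qx + q'x  s.t.  Ax = b  through its KKT system."""
    n, p = Q.shape[0], A.shape[0]
    # KKT conditions: Qx + q + A'y = 0 and Ax = b
    K = np.block([[Q, A.T], [A, np.zeros((p, p))]])
    sol = np.linalg.solve(K, np.concatenate([-q, b]))
    return sol[:n], sol[n:], K

def dx_dq(K, n):
    """Implicit function theorem: K @ d[x; y]/dq = -[I; 0]; keep the rows for x."""
    rhs = -np.vstack([np.eye(n), np.zeros((K.shape[0] - n, n))])
    return np.linalg.solve(K, rhs)[:n, :]

Q = np.array([[3.0, 1.0], [1.0, 2.0]])
q = np.array([1.0, -1.0])
A = np.array([[1.0, 1.0]])
b = np.array([1.0])

x, y, K = solve_eq_qp(Q, q, A, b)
J = dx_dq(K, n=2)

# finite-difference check of the Jacobian, one column per entry of q
eps = 1e-6
J_fd = np.zeros((2, 2))
for j in range(2):
    dq = np.zeros(2)
    dq[j] = eps
    x_plus, _, _ = solve_eq_qp(Q, q + dq, A, b)
    J_fd[:, j] = (x_plus - x) / eps
max_err = np.abs(J - J_fd).max()
```

Because the solution must satisfy `Ax = b` for every `q`, the Jacobian also satisfies `A @ J = 0`, which is a quick sanity check on the derivation.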

Installation

To install from PyPI using pip:

$ pip install qpax

Alternatively, to install from source in editable mode:

$ pip install -e .

Usage

Solving a QP

We can solve QPs with qpax in a way that plays nice with JAX's jit and vmap:

import qpax

# Q, q, A, b, G, h are JAX arrays defining the QP in the form above
# solve QP (this can be combined with jit or vmap)
x, s, z, y, converged, iters = qpax.solve_qp(Q, q, A, b, G, h)

Solving a batch of QPs

Here we solve a batch of nonnegative least squares problems as QPs. This demonstrates two features of qpax: solving QPs without any equality constraints, and using vmap to solve a batch of QPs.

import numpy as np
import jax.numpy as jnp
from jax import jit, vmap
import qpax

"""
solve batched non-negative least squares (nnls) problems
 
min_x    |Fx - g|^2 
st        x >= 0 
"""

n = 5   # size of x 
m = 10  # rows in F 

# create data for N_qps random nnls problems  
N_qps = 10000 
Fs = jnp.array(np.random.randn(N_qps, m, n))
gs = jnp.array(np.random.randn(N_qps, m))

@jit
def form_qp(F, g):
  # convert the least squares to qp form 
  n = F.shape[1]
  Q = F.T @ F 
  q = -F.T @ g 
  G = -jnp.eye(n)
  h = jnp.zeros(n)
  A = jnp.zeros((0, n))
  b = jnp.zeros(0)
  return Q, q, A, b, G, h

# create the QPs in a batched fashion 
Qs, qs, As, bs, Gs, hs = vmap(form_qp, in_axes = (0, 0))(Fs, gs)

# create function for solving a batch of QPs 
batch_qp = jit(vmap(qpax.solve_qp_primal, in_axes = (0, 0, 0, 0, 0, 0)))

xs = batch_qp(Qs, qs, As, bs, Gs, hs)
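The conversion in `form_qp` relies on expanding the least squares objective: $\frac{1}{2}\|Fx - g\|^2 = \frac{1}{2}x^TF^TFx - g^TFx + \frac{1}{2}g^Tg$, so with $Q = F^TF$ and $q = -F^Tg$ the two objectives differ only by a constant and share the same minimizer. The following NumPy sketch (independent of qpax, with illustrative variable names) checks this numerically for a single random problem.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 10, 5
F = rng.standard_normal((m, n))
g = rng.standard_normal(m)

# same construction as form_qp above
Q = F.T @ F
q = -F.T @ g

def ls_objective(x):
    # least squares objective, scaled by 1/2
    return 0.5 * np.sum((F @ x - g) ** 2)

def qp_objective(x):
    # QP objective in the standard form 0.5 x'Qx + q'x
    return 0.5 * x @ Q @ x + q @ x

# the gap between the two objectives is the same constant at any point
x1 = np.abs(rng.standard_normal(n))
x2 = np.abs(rng.standard_normal(n))
gap1 = ls_objective(x1) - qp_objective(x1)
gap2 = ls_objective(x2) - qp_objective(x2)
```

Since `Q = F.T @ F` is positive semidefinite by construction, the resulting QP always satisfies the convexity requirement.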

Differentiating a QP

Alternatively, if we only need the primal variable x, we can use solve_qp_primal, which supports automatic differentiation:

import jax 
import jax.numpy as jnp 
import qpax 

def loss(Q, q, A, b, G, h):
    x = qpax.solve_qp_primal(Q, q, A, b, G, h) 
    x_bar = jnp.ones(len(q))
    return jnp.dot(x - x_bar, x - x_bar)
  
# gradient of loss function   
loss_grad = jax.grad(loss, argnums = (0, 1, 2, 3, 4, 5))

# compatible with jit 
loss_grad_jit = jax.jit(loss_grad)

# calculate derivatives 
derivs = loss_grad_jit(Q, q, A, b, G, h)
dl_dQ, dl_dq, dl_dA, dl_db, dl_dG, dl_dh = derivs 

Citation

Paper

@misc{tracy2024differentiability,
    title={On the Differentiability of the Primal-Dual Interior-Point Method},
    author={Kevin Tracy and Zachary Manchester},
    year={2024},
    eprint={2406.11749},
    archivePrefix={arXiv},
    primaryClass={math.OC}
}
