🦖 Rax: Learning-to-Rank using JAX

Current release: 0.3.0 (PyPI)

Rax is a Learning-to-Rank library written in JAX. It provides off-the-shelf implementations of ranking losses and metrics, along with ranking-specific function transformations:

  • Ranking losses (rax.*_loss): rax.softmax_loss, rax.pairwise_logistic_loss, ...
  • Ranking metrics (rax.*_metric): rax.mrr_metric, rax.ndcg_metric, ...
  • Transformations (rax.*_t12n, where "t12n" abbreviates "transformation"): rax.approx_t12n, rax.gumbel_t12n, ...
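To make the loss and metric names above concrete, here is a minimal pure-Python sketch of two common textbook definitions: reciprocal rank (the per-list quantity that an MRR metric averages over lists) and a listwise softmax cross-entropy loss. The function names are hypothetical, and Rax's own rax.mrr_metric and rax.softmax_loss may differ in details such as label normalization, tie handling, masking, weights, and batching.

```python
import math

# Hypothetical helpers sketching common textbook definitions; they are not
# Rax's implementations and ignore batching, masking, and weights.


def reciprocal_rank(scores, labels):
    """Return 1/rank of the highest-ranked relevant item (label > 0), else 0."""
    # Sort item indices by descending score to obtain the ranked list.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    for rank, i in enumerate(order, start=1):
        if labels[i] > 0:
            return 1.0 / rank
    return 0.0  # no relevant item in this list


def softmax_cross_entropy(scores, labels):
    """Cross-entropy between normalized labels and the softmax of the scores."""
    total = sum(labels)
    if total == 0:
        return 0.0  # nothing relevant, so this list contributes no loss
    # Stable log-sum-exp: subtract the max score before exponentiating.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return -sum((y / total) * (s - log_z) for s, y in zip(scores, labels))


scores = [2.2, -1.3, 5.4]
labels = [1.0, 0.0, 0.0]
print(reciprocal_rank(scores, labels))        # 0.5 (relevant doc is ranked 2nd)
print(softmax_cross_entropy(scores, labels))  # about 3.2411
```

Averaging reciprocal_rank over many queries gives MRR; the softmax loss is small when the relevant items receive most of the probability mass.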

Ranking

A ranking problem differs from traditional classification or regression problems in that its objective is to get the relative order of a list of examples (e.g., documents) right for a given context (e.g., a query), rather than to predict a label or value for each example in isolation. Rax brings support for ranking problems to the JAX ecosystem. Applications include, but are not limited to:

  • Search: ranking a list of documents with respect to a query.
  • Recommendation: ranking a list of items given a user as context.
  • Question Answering: finding the best answer from a list of candidates.
  • Dialogue Systems: finding the best response from a list of responses.
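Because only the relative order matters, a ranking is fully determined by sorting items by score. The following pure-Python sketch (scores_to_ranks is a hypothetical helper, not a Rax function) shows that two very different score vectors can induce exactly the same ranking, even though a per-example regression loss would treat them very differently.

```python
# Hypothetical helper (not a Rax function): the 1-based rank of each item
# when the list is sorted by descending score.
def scores_to_ranks(scores):
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for rank, i in enumerate(order, start=1):
        ranks[i] = rank
    return ranks


print(scores_to_ranks([2.2, -1.3, 5.4]))   # [2, 3, 1]
# Very different scores with the same relative order induce the same ranking,
# so any metric that depends only on ranks treats the two lists identically.
print(scores_to_ranks([0.1, -50.0, 0.2]))  # [2, 3, 1]
```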

Synopsis

In a nutshell, given the scores and labels for a list of items, Rax can compute various ranking losses and metrics:

import jax
import jax.numpy as jnp
import rax

scores = jnp.array([2.2, -1.3, 5.4])  # output of a model.
labels = jnp.array([1.0,  0.0, 0.0])  # indicates doc 1 is relevant.

rax.ndcg_metric(scores, labels)  # computes a ranking metric.
# 0.63092977

rax.pairwise_hinge_loss(scores, labels)  # computes a ranking loss.
# 2.1
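To see where the two numbers above come from, here is a from-scratch, stdlib-only sketch using the common definitions: NDCG with gain 2^label - 1 and discount 1/log2(rank + 1), and a pairwise hinge loss max(0, 1 - (s_i - s_j)) over pairs where item i is more relevant than item j. The helper names are hypothetical, and averaging the hinge terms over valid pairs is an assumption chosen because it matches the 2.1 above; this is not Rax's implementation.

```python
import math

# Hypothetical helpers sketching common definitions; not Rax's code.


def dcg(scores, labels):
    # Rank items by descending score; gain 2^label - 1, discount 1/log2(rank + 1).
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    return sum((2.0 ** labels[i] - 1.0) / math.log2(rank + 1)
               for rank, i in enumerate(order, start=1))


def ndcg(scores, labels):
    # Normalize by the DCG of the ideal ordering (items sorted by label).
    ideal = dcg(labels, labels)
    return dcg(scores, labels) / ideal if ideal > 0 else 0.0


def pairwise_hinge(scores, labels):
    # Mean hinge loss over all pairs (i, j) where item i is more relevant than j.
    n = len(scores)
    losses = [max(0.0, 1.0 - (scores[i] - scores[j]))
              for i in range(n) for j in range(n)
              if labels[i] > labels[j]]
    return sum(losses) / len(losses) if losses else 0.0


scores = [2.2, -1.3, 5.4]
labels = [1.0, 0.0, 0.0]
print(ndcg(scores, labels))            # about 0.6309, i.e. 1 / log2(3)
print(pairwise_hinge(scores, labels))  # about 2.1 = (0.0 + 4.2) / 2
```

The relevant document is ranked second (behind the score 5.4), so NDCG is 1/log2(3). Of its two pairs, only the one against the higher-scored irrelevant document violates the margin, contributing 1 - (2.2 - 5.4) = 4.2.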

All Rax losses and metrics are purely functional and compose with standard JAX transformations such as jax.jit and jax.grad. Rax also provides ranking-specific transformations for building new ranking losses. For example, rax.approx_t12n transforms a (non-differentiable) ranking metric into a differentiable loss:

loss_fn = rax.approx_t12n(rax.ndcg_metric)
loss_fn(scores, labels)  # differentiable approximate NDCG loss (negated, so lower is better).
# -0.63282484

jax.grad(loss_fn)(scores, labels)  # computes gradients w.r.t. scores.
# [-0.01276882  0.00549765  0.00727116]
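One common way to make NDCG differentiable is the ApproxNDCG construction: each item's hard rank is replaced by a smooth rank, 1 + sum over j != i of sigmoid((s_j - s_i) / T), and the result is negated so that minimizing the loss maximizes the metric. The stdlib-only sketch below assumes a temperature T = 1.0 and uses central finite differences in place of jax.grad so it runs without JAX; the helper names are hypothetical and this is not Rax's code. Under these assumptions it approximately reproduces the loss and gradient values above.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def approx_ranks(scores, temperature=1.0):
    # Smooth rank of item i: 1 + sum_{j != i} sigmoid((s_j - s_i) / T).
    n = len(scores)
    return [1.0 + sum(sigmoid((scores[j] - scores[i]) / temperature)
                      for j in range(n) if j != i)
            for i in range(n)]


def approx_ndcg_loss(scores, labels, temperature=1.0):
    ranks = approx_ranks(scores, temperature)
    dcg = sum((2.0 ** y - 1.0) / math.log2(1.0 + r) for y, r in zip(labels, ranks))
    # The ideal DCG uses exact ranks of the label-sorted list; it ignores scores.
    ideal = sum((2.0 ** y - 1.0) / math.log2(rank + 1)
                for rank, y in enumerate(sorted(labels, reverse=True), start=1))
    # Negate so that minimizing the loss maximizes approximate NDCG.
    return -dcg / ideal if ideal > 0 else 0.0


def numerical_grad(f, scores, eps=1e-5):
    # Central finite differences, standing in for jax.grad in this sketch.
    grads = []
    for i in range(len(scores)):
        up = list(scores)
        up[i] += eps
        down = list(scores)
        down[i] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


scores = [2.2, -1.3, 5.4]
labels = [1.0, 0.0, 0.0]
print(approx_ndcg_loss(scores, labels))  # about -0.6328
print(numerical_grad(lambda s: approx_ndcg_loss(s, labels), scores))
# about [-0.0128, 0.0055, 0.0073]
```

The negative gradient on the relevant document says that raising its score lowers the loss, while the positive gradients push the irrelevant documents down, which is what a ranking loss should do.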

Installation

Rax requires JAX. See https://github.com/google/jax#installation for instructions on installing JAX.

We suggest installing the latest stable version of Rax by running:

$ pip install rax

Examples

See the examples/ directory in the Rax GitHub repository for complete examples of how to use Rax.

Citing Rax

If you use Rax, please consider citing our paper:

@inproceedings{jagerman2022rax,
  title = {Rax: Composable Learning-to-Rank using JAX},
  author  = {Rolf Jagerman and Xuanhui Wang and Honglei Zhuang and Zhen Qin and
  Michael Bendersky and Marc Najork},
  year  = {2022},
  booktitle = {Proceedings of the 28th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining}
}
