
open-metric-learning
OML is a PyTorch-based framework to train and validate the models producing high-quality embeddings.
A number of people from Oxford and HSE universities have used OML in their theses. [1] [2] [3]
You may think "If I need image embeddings I can simply train a vanilla classifier and take its penultimate layer". Well, it makes sense as a starting point. But there are several possible drawbacks:
If you want to use embeddings to perform searching you need to calculate some distance among them (for example, cosine or L2). Usually, you don't directly optimize these distances during the training in the classification setup. So, you can only hope that final embeddings will have the desired properties.
The second problem is the validation process. In the searching setup, you usually care how related your top-N outputs are to the query. The natural way to evaluate the model is to simulate searching requests to the reference set and apply one of the retrieval metrics. So, there is no guarantee that classification accuracy will correlate with these metrics.
Finally, you may want to implement a metric learning pipeline by yourself. That is a lot of work: to use triplet loss you need to form batches in a specific way, implement different kinds of triplet mining, track distances, etc. For the validation, you also need to implement retrieval metrics, which includes efficiently accumulating embeddings during the epoch, covering corner cases, etc. It's even harder if you have several GPUs and use DDP. You may also want to visualize your search requests by highlighting good and bad search results. Instead of doing it all by yourself, you can simply use OML.
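To make the difference between classification accuracy and a retrieval metric concrete, here is a minimal sketch of CMC@k: the share of queries that have at least one item with the correct label among their k nearest gallery items. It is written in plain Python to stay dependency-free; the helper names (`l2_distance`, `cmc_at_k`) and the toy data are hypothetical and are not OML's API.

```python
import math


def l2_distance(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cmc_at_k(query_embs, query_labels, gallery_embs, gallery_labels, k=1):
    """Share of queries with a correct label among the k nearest gallery items."""
    hits = 0
    for q_emb, q_label in zip(query_embs, query_labels):
        # rank gallery items by distance to the query, closest first
        ranked = sorted(
            range(len(gallery_embs)),
            key=lambda i: l2_distance(q_emb, gallery_embs[i]),
        )
        top_labels = [gallery_labels[i] for i in ranked[:k]]
        hits += int(q_label in top_labels)
    return hits / len(query_embs)


# toy 2-D embeddings: labels 0 and 1 form two separated clusters
gallery_embs = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]]
gallery_labels = [0, 0, 1, 1]
# the third query has label 1 but lies slightly closer to the cluster of label 0
query_embs = [[0.05, 0.1], [4.9, 5.2], [2.4, 2.4]]
query_labels = [0, 1, 1]

score = cmc_at_k(query_embs, query_labels, gallery_embs, gallery_labels, k=1)
print(f"CMC@1 = {score:.3f}")
```

The third query is ranked incorrectly at k=1 but is recovered at k=3, which is the kind of behaviour a plain accuracy number does not describe.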
PML is a popular library for Metric Learning, and it includes a rich collection of losses, miners, distances, and reducers; that is why we provide straightforward examples of using them with OML. Initially, we tried to use PML, but in the end, we came up with our own library, which is more pipeline / recipes oriented. That is how OML differs from PML:
OML has Pipelines which allow training models by preparing a config and your data in the required format (it's like converting data into COCO format to train a detector from mmdetection).
OML focuses on end-to-end pipelines and practical use cases. It has config-based examples on popular benchmarks close to real life (like photos of products with thousands of ids). We found some good combinations of hyperparameters on these datasets, and trained and published models and their configs. This makes OML more recipes oriented than PML, and its author confirms this, saying that his library is a set of tools rather than recipes. Moreover, the examples in PML are mostly for the CIFAR and MNIST datasets.
OML has the Zoo of pretrained models that can be easily accessed from the code in the same way as in torchvision (when you type resnet50(pretrained=True)).
OML is integrated with PyTorch Lightning, so we can use the power of its Trainer. This is especially helpful when working with DDP: compare our DDP example and the PML one. By the way, PML also has Trainers, but they are not widely used in its examples, where custom train / test functions are used instead.
We believe that having Pipelines, concise examples, and the Zoo of pretrained models keeps the entry threshold really low.
Metric Learning problem (also known as extreme classification problem) means a situation in which we have thousands of ids of some entities, but only a few samples for every entity. Often we assume that during the test stage (or production) we will deal with unseen entities which makes it impossible to apply the vanilla classification pipeline directly. In many cases obtained embeddings are used to perform search or matching procedures over them.
Here are a few examples of such tasks from the computer vision sphere:
embedding - model's output (also known as features vector or descriptor).
query - a sample which is used as a request in the retrieval procedure.
gallery set - the set of entities to search items similar to query (also known as reference or index).
Sampler - an argument for DataLoader which is used to form batches.
Miner - the object to form pairs or triplets after the batch was formed by Sampler. It's not necessary to form the combinations of samples only inside the current batch, thus, the memory bank may be a part of Miner.
Samples/Labels/Instances - as an example let's consider DeepFashion dataset. It includes thousands of fashion item ids (we name them labels) and several photos for each item id (we name the individual photo as instance or sample). All of the fashion item ids have their groups like "skirts", "jackets", "shorts" and so on (we name them categories). Note, we avoid using the term class to avoid misunderstanding.
training epoch - batch samplers which we use for combination-based losses usually have a length equal to [number of labels in training dataset] / [numbers of labels in one batch]. It means that we don't observe all of the available training samples in one epoch (as opposed to vanilla classification), instead, we observe all of the available labels.
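The Sampler and training epoch definitions above can be illustrated with a small sketch of a balanced batch sampler. It is a hypothetical helper written in plain Python, not OML's implementation: each batch takes `n_labels` labels and `n_instances` samples per label, so one epoch consists of [number of labels] / [labels per batch] batches.

```python
import random
from collections import defaultdict


def balanced_batches(labels, n_labels, n_instances, seed=0):
    """Yield batches of dataset indices with n_labels labels and n_instances samples each.

    Every label is visited once per epoch, so the number of batches is
    len(set(labels)) // n_labels.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    unique_labels = list(by_label)
    rng.shuffle(unique_labels)

    n_batches = len(unique_labels) // n_labels
    for b in range(n_batches):
        batch = []
        for label in unique_labels[b * n_labels:(b + 1) * n_labels]:
            pool = by_label[label]
            if len(pool) >= n_instances:
                # enough samples: draw without replacement
                batch.extend(rng.sample(pool, n_instances))
            else:
                # too few samples: fall back to drawing with replacement
                batch.extend(rng.choices(pool, k=n_instances))
        yield batch


# 6 labels with 3 samples each -> 18 samples in total
labels = [label for label in range(6) for _ in range(3)]
batches = list(balanced_batches(labels, n_labels=2, n_instances=2))

print(len(batches))                  # 6 labels / 2 labels per batch = 3 batches
print(sum(len(b) for b in batches))  # 12 of the 18 samples are seen in this epoch
```

With PyTorch, a generator like this would typically be wrapped into an object passed to DataLoader via its `batch_sampler` argument.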
It may be comparable with the current (as of 2022) SotA methods, for example, Hyp-ViT. (A few words about this approach: it's a ViT architecture trained with contrastive loss, but the embeddings are projected into a hyperbolic space. As the authors claim, such a space is able to describe the nested structure of real-world data. So, the paper requires some heavy math to adapt the usual operations to the hyperbolic space.)
We trained the same architecture with triplet loss, fixing the rest of the parameters: training and test transformations, image size, and optimizer. See configs in Models Zoo. The trick was in heuristics in our miner and sampler:
Category Balance Sampler forms the batches limiting the number of categories C in it. For instance, when C = 1 it puts only jackets in one batch and only jeans into another one (just an example). It automatically makes the negative pairs harder: it's more meaningful for a model to realise why two jackets are different than to understand the same about a jacket and a t-shirt.
Hard Triplets Miner makes the task even harder keeping only the hardest triplets (with maximal positive and minimal negative distances).
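The idea behind hard triplet mining can be sketched as follows: for every anchor in the batch, keep the same-label sample with the maximal distance and the different-label sample with the minimal distance. This is a plain-Python illustration under our own assumptions (the function names and the 1-D toy embeddings are hypothetical), not the code of OML's miner.

```python
import math


def l2_distance(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def mine_hardest_triplets(embeddings, labels):
    """Return (anchor, hardest_positive, hardest_negative) index triplets.

    Hardest positive: same label, maximal distance to the anchor.
    Hardest negative: different label, minimal distance to the anchor.
    Anchors without a positive or a negative in the batch are skipped.
    """
    triplets = []
    for a, (emb_a, label_a) in enumerate(zip(embeddings, labels)):
        positives = [i for i, lb in enumerate(labels) if lb == label_a and i != a]
        negatives = [i for i, lb in enumerate(labels) if lb != label_a]
        if not positives or not negatives:
            continue
        hardest_pos = max(positives, key=lambda i: l2_distance(emb_a, embeddings[i]))
        hardest_neg = min(negatives, key=lambda i: l2_distance(emb_a, embeddings[i]))
        triplets.append((a, hardest_pos, hardest_neg))
    return triplets


# 1-D toy embeddings so the distances are easy to check by hand
embeddings = [[0.0], [1.0], [3.0], [1.5], [4.0]]
labels = [0, 0, 0, 1, 1]

triplets = mine_hardest_triplets(embeddings, labels)
print(triplets)  # [(0, 2, 3), (1, 2, 3), (2, 0, 4), (3, 4, 1), (4, 3, 2)]
```

When a batch contains a single category (C = 1 in the sampler above), the negatives available to this procedure are already similar items, so the mined triplets become harder still.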
Here are CMC@1 scores for 2 popular benchmarks. SOP dataset: Hyp-ViT — 85.9, ours — 86.6. DeepFashion dataset: Hyp-ViT — 92.5, ours — 92.1. Thus, utilising simple heuristics and avoiding heavy math we are able to perform on SotA level.
Recent research in SSL has definitely obtained great results. The problem is that these approaches require an enormous amount of compute to train the model. In our framework, we consider the most common case, when the average user has no more than a few GPUs.
At the same time, it would be unwise to ignore success in this sphere, so we still exploit it in two ways:
No, you don't. OML is framework-agnostic. Although we use PyTorch Lightning as a loop runner for the experiments, we also keep the possibility to run everything on pure PyTorch. Thus, only a tiny part of OML is Lightning-specific, and we keep this logic separate from the other code (see oml.lightning). Even when you use Lightning, you don't need to know it, since we provide ready-to-use Pipelines.
The possibility of using pure PyTorch and the modular structure of the code leave room for using OML with your favourite framework once the necessary wrappers are implemented.
Yes. To run the experiment with Pipelines you only need to write a converter to our format (it means preparing the .csv table with a few predefined columns). That's it!
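A converter of this kind can be as small as the sketch below. The column names (`label`, `path`, `split`, `is_query`, `is_gallery`) and the convention of leaving the query/gallery flags empty for training rows reflect our understanding of the dataset format and should be treated as assumptions: check the dataset format documentation before relying on them. The function name, file paths, and records are hypothetical.

```python
import csv
import tempfile
from pathlib import Path

# Assumed column set; verify against the dataset format documentation.
COLUMNS = ["label", "path", "split", "is_query", "is_gallery"]


def write_dataset_csv(records, out_path, val_labels):
    """Write (label, path) records to a .csv; labels in val_labels go to validation."""
    with open(out_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=COLUMNS)
        writer.writeheader()
        for label, path in records:
            is_val = label in val_labels
            writer.writerow({
                "label": label,
                "path": path,
                "split": "validation" if is_val else "train",
                # here every validation item is both a query and a gallery item;
                # training rows leave the flags empty
                "is_query": True if is_val else "",
                "is_gallery": True if is_val else "",
            })


# hypothetical records: (label, path to the image)
records = [
    (0, "images/item0_a.jpg"),
    (0, "images/item0_b.jpg"),
    (1, "images/item1_a.jpg"),
    (1, "images/item1_b.jpg"),
]
out_path = Path(tempfile.mkdtemp()) / "df.csv"
write_dataset_csv(records, out_path, val_labels={1})
print(out_path.read_text())
```

Note that the train and validation splits here contain different labels, which matches the unseen-entities setup described earlier.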
Probably we already have a suitable pre-trained model for your domain in our Models Zoo. In this case, you don't even need to train it.
Currently, we don't support exporting models to ONNX directly. However, you can use the built-in PyTorch capabilities to achieve this. For more information, please refer to this issue.
TUTORIAL TO START WITH: English | Russian | Chinese
The DEMO for our paper STIR: Siamese Transformers for Image Retrieval Postprocessing
Meet OpenMetricLearning (OML) on Marktechpost
The report for Berlin-based meetup: "Computer Vision in production". November, 2022. Link
pip install -U open-metric-learning # minimum dependencies
pip install -U open-metric-learning[nlp]
pip install -U open-metric-learning[audio]
pip install -U open-metric-learning[pipelines]
# in the case of conflicts install without dependencies and manage versions manually:
pip install --no-deps open-metric-learning
docker pull omlteam/oml:gpu
docker pull omlteam/oml:cpu
Losses | Miners | Samplers | Configs support | Pre-trained models of different modalities | Post-processing | Post-processing by NN | Paper | Logging | PML | Categories support | Misc metrics | Lightning | Lightning DDP
Here is an example of how to train, validate and post-process the model on a tiny dataset of images, texts, or audios. See more details on dataset format.
Extra illustrations, explanations and tips for the code above.
Here is an inference time example (in other words, retrieval on test set). The code below works for both texts and images.
from oml.datasets import ImageQueryGalleryDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.utils import get_mock_images_dataset
from oml.retrieval import RetrievalResults, AdaptiveThresholding
_, df_test = get_mock_images_dataset(global_paths=True)
del df_test["label"] # we don't need gt labels for doing predictions
extractor = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transform, _ = get_transforms_for_pretrained("vits16_dino")
dataset = ImageQueryGalleryDataset(df_test, transform=transform)
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, dataset, n_items=5)
rr = AdaptiveThresholding(n_std=3.5).process(rr)
rr.visualize(query_ids=[0, 1], dataset=dataset, show=True)
# you get the ids of retrieved items and the corresponding distances
print(rr)
Here is an example where queries and galleries are processed separately.
import pandas as pd
from oml.datasets import ImageBaseDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.retrieval import RetrievalResults, ConstantThresholding
from oml.utils import get_mock_images_dataset
extractor = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transform, _ = get_transforms_for_pretrained("vits16_dino")
paths = pd.concat(get_mock_images_dataset(global_paths=True))["path"]
galleries, queries1, queries2 = paths[:20], paths[20:22], paths[22:24]
# gallery is huge and fixed, so we only process it once
dataset_gallery = ImageBaseDataset(galleries, transform=transform)
embeddings_gallery = inference(extractor, dataset_gallery, batch_size=4, num_workers=0)
# queries come "online" in stream
for queries in [queries1, queries2]:
    dataset_query = ImageBaseDataset(queries, transform=transform)
    embeddings_query = inference(extractor, dataset_query, batch_size=4, num_workers=0)
    # for the operation below we are going to provide integrations with vector search DB like QDrant or Faiss
    rr = RetrievalResults.from_embeddings_qg(
        embeddings_query=embeddings_query, embeddings_gallery=embeddings_gallery,
        dataset_query=dataset_query, dataset_gallery=dataset_gallery
    )
    rr = ConstantThresholding(th=80).process(rr)
    rr.visualize_qg([0, 1], dataset_query=dataset_query, dataset_gallery=dataset_gallery, show=True)
    print(rr)
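Conceptually, a constant threshold like the one used above drops retrieved items that are too far from the query. The sketch below shows that idea on plain lists; it is an assumption about the behaviour, not the library's implementation (for instance, whether the comparison is strict may differ), and the function name and toy values are hypothetical.

```python
def constant_threshold(distances, retrieved_ids, th):
    """Keep only the retrieved items whose distance to the query is below th.

    distances[q] and retrieved_ids[q] describe the ranked results of query q.
    """
    kept_distances, kept_ids = [], []
    for dists, ids in zip(distances, retrieved_ids):
        kept = [(d, i) for d, i in zip(dists, ids) if d < th]
        kept_distances.append([d for d, _ in kept])
        kept_ids.append([i for _, i in kept])
    return kept_distances, kept_ids


# two queries with three retrieved gallery items each, sorted by distance
distances = [[10.0, 75.0, 120.0], [90.0, 95.0, 130.0]]
retrieved_ids = [[4, 1, 7], [2, 0, 5]]

kept_distances, kept_ids = constant_threshold(distances, retrieved_ids, th=80)
print(kept_ids)  # [[4, 1], []]
```

As the second query shows, a query may end up with no results at all, which downstream code has to handle.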
Pipelines provide a way to run metric learning experiments via changing only the config file. All you need is to prepare your dataset in a required format.
See Pipelines folder for more details:
You can use an image model from our Zoo or any other model that inherits from IExtractor.
from oml.const import CKPT_SAVE_ROOT as CKPT_DIR, MOCK_DATASET_PATH as DATA_DIR
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
model = ViTExtractor.from_pretrained("vits16_dino").eval()
transforms, im_reader = get_transforms_for_pretrained("vits16_dino")
img = im_reader(DATA_DIR / "images" / "circle_1.jpg") # put path to your image here
img_tensor = transforms(img)
# img_tensor = transforms(image=img)["image"] # for transforms from Albumentations
features = model(img_tensor.unsqueeze(0))
# Check other available models:
print(list(ViTExtractor.pretrained_models.keys()))
# Load checkpoint saved on a disk:
model_ = ViTExtractor(weights=CKPT_DIR / "vits16_dino.ckpt", arch="vits16", normalise_features=False)
Models trained by us. The metrics below are for 224 x 224 images:
| model | cmc1 | dataset | weights | experiment |
|---|---|---|---|---|
| ViTExtractor.from_pretrained("vits16_inshop") | 0.921 | DeepFashion Inshop | link | link |
| ViTExtractor.from_pretrained("vits16_sop") | 0.866 | Stanford Online Products | link | link |
| ViTExtractor.from_pretrained("vits16_cars") | 0.907 | CARS 196 | link | link |
| ViTExtractor.from_pretrained("vits16_cub") | 0.837 | CUB 200 2011 | link | link |
Models trained by other researchers.
Note that some metrics on particular benchmarks are very high because these benchmarks were part of the training dataset (for example, for unicom).
The metrics below are for 224 x 224 images:
| model | Stanford Online Products | DeepFashion InShop | CUB 200 2011 | CARS 196 |
|---|---|---|---|---|
| ViTUnicomExtractor.from_pretrained("vitb16_unicom") | 0.700 | 0.734 | 0.847 | 0.916 |
| ViTUnicomExtractor.from_pretrained("vitb32_unicom") | 0.690 | 0.722 | 0.796 | 0.893 |
| ViTUnicomExtractor.from_pretrained("vitl14_unicom") | 0.726 | 0.790 | 0.868 | 0.922 |
| ViTUnicomExtractor.from_pretrained("vitl14_336px_unicom") | 0.745 | 0.810 | 0.875 | 0.924 |
| ViTCLIPExtractor.from_pretrained("sber_vitb32_224") | 0.547 | 0.514 | 0.448 | 0.618 |
| ViTCLIPExtractor.from_pretrained("sber_vitb16_224") | 0.565 | 0.565 | 0.524 | 0.648 |
| ViTCLIPExtractor.from_pretrained("sber_vitl14_224") | 0.512 | 0.555 | 0.606 | 0.707 |
| ViTCLIPExtractor.from_pretrained("openai_vitb32_224") | 0.612 | 0.491 | 0.560 | 0.693 |
| ViTCLIPExtractor.from_pretrained("openai_vitb16_224") | 0.648 | 0.606 | 0.665 | 0.767 |
| ViTCLIPExtractor.from_pretrained("openai_vitl14_224") | 0.670 | 0.675 | 0.745 | 0.844 |
| ViTExtractor.from_pretrained("vits16_dino") | 0.648 | 0.509 | 0.627 | 0.265 |
| ViTExtractor.from_pretrained("vits8_dino") | 0.651 | 0.524 | 0.661 | 0.315 |
| ViTExtractor.from_pretrained("vitb16_dino") | 0.658 | 0.514 | 0.541 | 0.288 |
| ViTExtractor.from_pretrained("vitb8_dino") | 0.689 | 0.599 | 0.506 | 0.313 |
| ViTExtractor.from_pretrained("vits14_dinov2") | 0.566 | 0.334 | 0.797 | 0.503 |
| ViTExtractor.from_pretrained("vits14_reg_dinov2") | 0.566 | 0.332 | 0.795 | 0.740 |
| ViTExtractor.from_pretrained("vitb14_dinov2") | 0.565 | 0.342 | 0.842 | 0.644 |
| ViTExtractor.from_pretrained("vitb14_reg_dinov2") | 0.557 | 0.324 | 0.833 | 0.828 |
| ViTExtractor.from_pretrained("vitl14_dinov2") | 0.576 | 0.352 | 0.844 | 0.692 |
| ViTExtractor.from_pretrained("vitl14_reg_dinov2") | 0.571 | 0.340 | 0.840 | 0.871 |
| ResnetExtractor.from_pretrained("resnet50_moco_v2") | 0.493 | 0.267 | 0.264 | 0.149 |
| ResnetExtractor.from_pretrained("resnet50_imagenet1k_v1") | 0.515 | 0.284 | 0.455 | 0.247 |
The metrics may be different from the ones reported by papers, because the version of train/val split and usage of bounding boxes may differ.
Here is a lightweight integration with HuggingFace Transformers models. You can replace it with other arbitrary models inherited from IExtractor.
pip install open-metric-learning[nlp]
from transformers import AutoModel, AutoTokenizer
from oml.models import HFWrapper
model = AutoModel.from_pretrained('bert-base-uncased').eval()
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
extractor = HFWrapper(model=model, feat_dim=768)
inp = tokenizer(text="Hello world", return_tensors="pt", add_special_tokens=True)
embeddings = extractor(inp)
Note, we don't have our own text models zoo at the moment.
You can use an audio model from our Zoo or any other model that inherits from IExtractor.
pip install open-metric-learning[audio]
import torchaudio
from oml.models import ECAPATDNNExtractor
from oml.const import CKPT_SAVE_ROOT as CKPT_DIR, MOCK_AUDIO_DATASET_PATH as DATA_DIR
# replace it by your actual paths
ckpt_path = CKPT_DIR / "ecapa_tdnn_taoruijie.pth"
file_path = DATA_DIR / "voices" / "voice0_0.wav"
model = ECAPATDNNExtractor(weights=ckpt_path, arch="ecapa_tdnn_taoruijie", normalise_features=False).to("cpu").eval()
audio, sr = torchaudio.load(file_path)
if audio.shape[0] > 1:
    audio = audio.mean(dim=0, keepdim=True) # mean by channels
if sr != 16000:
    audio = torchaudio.functional.resample(audio, sr, 16000)
embeddings = model.extract(audio)
| model | Vox1_O | Vox1_E | Vox1_H |
|---|---|---|---|
| ECAPATDNNExtractor.from_pretrained("ecapa_tdnn_taoruijie") | 0.86 | 1.18 | 2.17 |
The metrics above represent Equal Error Rate (EER). Lower is better.
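For readers unfamiliar with the metric, here is a minimal sketch of how EER can be estimated from verification scores: sweep a decision threshold and find the point where the false acceptance rate (impostor pairs accepted) equals the false rejection rate (genuine pairs rejected). The function name and the toy similarity scores are hypothetical; this is a coarse estimate over the observed scores, not the evaluation protocol used for the table above.

```python
def equal_error_rate(genuine_scores, impostor_scores):
    """Estimate EER by sweeping thresholds over all observed similarity scores.

    A pair is accepted when its score is >= threshold.
    FAR: share of impostor pairs accepted. FRR: share of genuine pairs rejected.
    """
    thresholds = sorted(set(genuine_scores) | set(impostor_scores))
    best_gap, eer = float("inf"), None
    for t in thresholds:
        far = sum(s >= t for s in impostor_scores) / len(impostor_scores)
        frr = sum(s < t for s in genuine_scores) / len(genuine_scores)
        gap = abs(far - frr)
        if gap < best_gap:
            # the closest point to FAR == FRR seen so far
            best_gap, eer = gap, (far + frr) / 2
    return eer


# toy similarity scores: higher means "more likely the same speaker"
genuine_scores = [0.9, 0.8, 0.7, 0.35]
impostor_scores = [0.6, 0.4, 0.3, 0.2]

eer = equal_error_rate(genuine_scores, impostor_scores)
print(f"EER = {eer:.2f}")  # EER = 0.25
```

At the threshold 0.6, one of four impostor pairs is accepted and one of four genuine pairs is rejected, so both error rates equal 0.25.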
We welcome new contributors! Please, see our:
The project was started in 2020 as a module for the Catalyst library. I want to thank the people who worked with me on that module: Julia Shenshina, Nikita Balagansky, Sergey Kolesnikov and others.
I would like to thank the people who continued working on this pipeline after it became a separate project: Julia Shenshina, Misha Kindulov, Aron Dik, Aleksei Tarasov and Verkhovtsev Leonid.
I also want to thank NewYorker, since part of the functionality was developed (and used) by its computer vision team led by me.