torchmetrics-ext 0.3.0 (PyPI)

An extension of the torchmetrics package.

TorchMetrics Extension


Installation

Simple installation from PyPI

pip install torchmetrics-ext

What is TorchMetrics Extension?

It is an extension of torchmetrics that adds more metrics for machine learning tasks. It offers:

  • A standardized interface to increase reproducibility
  • Reduced boilerplate
  • Compatibility with distributed training
  • Rigorous testing
  • Automatic accumulation over batches
  • Automatic synchronization between multiple devices
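The accumulation and synchronization features follow torchmetrics' stateful metric design. A metric keeps running state, update() folds each batch into that state, compute() derives the score from it, and reset() clears it. The pure-Python sketch below shows the pattern. The class RunningAccuracy and its merge() method are hypothetical illustrations and are not torchmetrics-ext code. In torchmetrics itself, states are registered with add_state(..., dist_reduce_fx="sum"), and cross-device synchronization does the work that merge() does here.

```python
from typing import Sequence


class RunningAccuracy:
    """Hypothetical metric: accumulates correct/total counts across batches."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear accumulated state, e.g. at the start of an evaluation epoch.
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], targets: Sequence[int]) -> None:
        # Fold one batch's statistics into the running state.
        if len(preds) != len(targets):
            raise ValueError("preds and targets must have the same length")
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        # The final score depends only on the accumulated state.
        return self.correct / self.total if self.total else 0.0

    def merge(self, other: "RunningAccuracy") -> None:
        # Syncing devices amounts to summing their per-device states.
        self.correct += other.correct
        self.total += other.total
```

Because the state consists of plain counts, summing them across batches or devices gives the same result as evaluating all samples at once.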

Currently, it offers metrics for 3D visual grounding tasks, including ScanRefer, Nr3D, and Multi3DRefer.

Using TorchMetrics Extension

Here are examples of using the metrics in TorchMetrics Extension:

ScanRefer

Download the ScanRefer dataset first. The evaluator requires it, and the example below reads the split file ScanRefer_filtered_val.json through dataset_file_path.

It measures the thresholded accuracy Acc@kIoU: a prediction counts as correct when its intersection over union (IoU) with the ground-truth box exceeds the threshold k. The metric is based on the ScanRefer task.

import torch
from torchmetrics_ext.metrics.visual_grounding import ScanReferMetric
metric = ScanReferMetric(dataset_file_path="./ScanRefer_filtered_val.json", split="validation")

# preds is a dictionary mapping each unique description identifier (formatted as "{scene_id}_{object_id}_{ann_id}")
# to the predicted axis-aligned bounding boxes in shape (2, 3)
preds = {
    "scene0011_00_0_0": torch.tensor([[0., 0., 0.], [0.5, 0.5, 0.5]]),
    "scene0011_01_0_1": torch.tensor([[0., 0., 0.], [1., 1., 1.]]),
}
metric(preds)
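The pure-Python sketch below makes Acc@kIoU concrete. It rests on several assumptions. Each (2, 3) box is taken to be laid out as [min corner, max corner]. Accuracy is computed over all ground-truth descriptions, and a missing prediction counts as a miss. "Higher than" is read as a strict IoU > k comparison. The helper names (parse_scanrefer_key, aabb_iou, acc_at_k_iou) are hypothetical and are not part of torchmetrics-ext. The key parser splits from the right because the scene identifier itself contains an underscore (e.g. scene0011_00).

```python
from typing import Dict, Sequence, Tuple

# Assumed box layout: (min corner (x, y, z), max corner (x, y, z)).
Box = Tuple[Sequence[float], Sequence[float]]


def parse_scanrefer_key(key: str) -> Tuple[str, int, int]:
    """Split "{scene_id}_{object_id}_{ann_id}" from the right; scene_id contains "_"."""
    scene_id, object_id, ann_id = key.rsplit("_", 2)
    return scene_id, int(object_id), int(ann_id)


def box_volume(box: Box) -> float:
    """Volume of an axis-aligned box given its min and max corners."""
    lo, hi = box
    return (hi[0] - lo[0]) * (hi[1] - lo[1]) * (hi[2] - lo[2])


def aabb_iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned 3D boxes."""
    inter = 1.0
    for axis in range(3):
        overlap = min(a[1][axis], b[1][axis]) - max(a[0][axis], b[0][axis])
        if overlap <= 0:
            return 0.0  # the boxes do not overlap along this axis
        inter *= overlap
    union = box_volume(a) + box_volume(b) - inter
    return inter / union if union > 0 else 0.0


def acc_at_k_iou(preds: Dict[str, Box], gts: Dict[str, Box], k: float) -> float:
    """Fraction of ground-truth descriptions whose predicted box has IoU > k."""
    if not gts:
        return 0.0
    hits = sum(1 for key, gt in gts.items() if key in preds and aabb_iou(preds[key], gt) > k)
    return hits / len(gts)
```

For example, a predicted box [[0, 0, 0], [0.5, 0.5, 0.5]] against a ground-truth unit cube at the origin has IoU 0.125. Under this sketch it would therefore miss at k = 0.25.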

Nr3D

The dataset will be automatically downloaded from the official Nr3D Google Drive.

It measures the accuracy of selecting the target object from the candidates. The metric is based on the Nr3D task.

import torch
from torchmetrics_ext.metrics.visual_grounding import Nr3DMetric

metric = Nr3DMetric(split="test")

# indices of the predicted and ground-truth target objects among each sample's candidates, shape (B,)
pred_indices = torch.tensor([5, 2, 0, 0], dtype=torch.uint8)
gt_indices = torch.tensor([5, 5, 1, 0], dtype=torch.uint8)

# evaluation types of each sample: difficulty ("easy"/"hard") and view dependence ("view_dep"/"view_indep")
gt_eval_types = (("easy", "view_dep"), ("easy", "view_indep"), ("hard", "view_dep"), ("hard", "view_dep"))
results = metric(pred_indices, gt_indices, gt_eval_types)
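The evaluation types allow accuracy to be reported for each subset as well as overall. The sketch below shows that bookkeeping. It assumes every sample counts toward the overall score and toward each subset named in its tuple. The function nr3d_accuracy and its "overall" key are hypothetical and may not match the keys that Nr3DMetric actually returns.

```python
from collections import defaultdict
from typing import Dict, Sequence


def nr3d_accuracy(
    pred_indices: Sequence[int],
    gt_indices: Sequence[int],
    eval_types: Sequence[Sequence[str]],
) -> Dict[str, float]:
    """Overall accuracy plus accuracy within each evaluation subset (hypothetical keys)."""
    if not (len(pred_indices) == len(gt_indices) == len(eval_types)):
        raise ValueError("all inputs must have one entry per sample")
    correct: Dict[str, int] = defaultdict(int)
    total: Dict[str, int] = defaultdict(int)
    for pred, gt, types in zip(pred_indices, gt_indices, eval_types):
        hit = int(pred == gt)  # correct if the chosen candidate is the target object
        for group in ("overall", *types):
            correct[group] += hit
            total[group] += 1
    return {group: correct[group] / total[group] for group in total}
```

Applied to the inputs from the Nr3D example above, this sketch gives 0.5 overall and 0.0 on the single view_indep sample.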

Multi3DRefer

The dataset will be automatically downloaded from the official Multi3DRefer Hugging Face repo.

It measures F1-scores at multiple IoU thresholds (F1@kIoU): a predicted box counts as a true positive when its intersection over union (IoU) with a ground-truth box exceeds the threshold k. The metric is based on the Multi3DRefer task.

import torch
from torchmetrics_ext.metrics.visual_grounding import Multi3DReferMetric
metric = Multi3DReferMetric(split="validation")

# preds is a dictionary mapping each unique description identifier (formatted as "{scene_id}_{ann_id}")
# to a variable number of predicted axis-aligned bounding boxes in shape (N, 2, 3)
preds = {
    "scene0011_00_0": torch.tensor([[[0., 0., 0.], [0.5, 0.5, 0.5]]]),  # 1 predicted box
    "scene0011_01_1": torch.tensor([  # 2 predicted boxes
        [[0., 0., 0.], [1., 1., 1.]],
        [[0., 0., 0.], [2., 2., 2.]],
    ]),
}
result = metric(preds)
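The sketch below computes F1@kIoU for a single description. Its assumptions are all hypothetical:

- Boxes use the [min corner, max corner] layout.
- Predicted and ground-truth boxes are matched one-to-one, greedily, by descending IoU. The benchmark may use an optimal matching instead.
- A match counts as a true positive only when IoU > k.
- An empty prediction for a description with no target box scores 1.0.

This README does not state how the package aggregates scores across descriptions.

```python
from typing import List, Sequence, Set, Tuple

# Assumed box layout: (min corner (x, y, z), max corner (x, y, z)).
Box = Tuple[Sequence[float], Sequence[float]]


def aabb_iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned 3D boxes."""
    inter, vol_a, vol_b = 1.0, 1.0, 1.0
    for axis in range(3):
        inter *= max(0.0, min(a[1][axis], b[1][axis]) - max(a[0][axis], b[0][axis]))
        vol_a *= a[1][axis] - a[0][axis]
        vol_b *= b[1][axis] - b[0][axis]
    union = vol_a + vol_b - inter
    return inter / union if union > 0 else 0.0


def f1_at_k_iou(pred_boxes: List[Box], gt_boxes: List[Box], k: float) -> float:
    """F1 for one description, using greedy one-to-one matching by descending IoU."""
    if not pred_boxes and not gt_boxes:
        return 1.0  # assumption: correctly predicting "no target" is a perfect score
    # All (IoU, pred index, gt index) candidate pairs, best overlaps first.
    pairs = sorted(
        ((aabb_iou(p, g), i, j) for i, p in enumerate(pred_boxes) for j, g in enumerate(gt_boxes)),
        reverse=True,
    )
    used_preds: Set[int] = set()
    used_gts: Set[int] = set()
    true_positives = 0
    for iou, i, j in pairs:
        if iou <= k:
            break  # all remaining pairs overlap even less
        if i in used_preds or j in used_gts:
            continue  # each box may be matched at most once
        used_preds.add(i)
        used_gts.add(j)
        true_positives += 1
    if true_positives == 0:
        return 0.0
    precision = true_positives / len(pred_boxes)
    recall = true_positives / len(gt_boxes)
    return 2 * precision * recall / (precision + recall)
```

The one-to-one constraint is what makes duplicate predictions costly. Two identical boxes on a single target yield one true positive and one false positive, which gives F1 = 2/3 rather than 1.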
