Tensorflow keras computer vision attention models. Alias kecam. https://github.com/leondgarse/keras_cv_attention_models

Keras_cv_attention_models


  • WARNING: currently NOT compatible with keras 3.0. If using tensorflow>=2.16.0, install tf-keras manually: pip install tf-keras~=2.16.
  • coco_train_script.py is under testing. Still struggling with this...
  • The RepViT architecture changed to adapt new weights since kecam>1.3.22.

General Usage

Basic

  • Currently recommended TF version is tensorflow==2.11.1, especially for training or TFLite conversion.
  • The READMEs use the following imports by default without specifying them each time:
    import os
    import sys
    import tensorflow as tf
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from tensorflow import keras
    
  • Install as a pip package. kecam is a short alias name of this package. Note: the pip package kecam doesn't set any backend requirement; make sure either TensorFlow or PyTorch is installed beforehand. For PyTorch backend usage, refer to Keras PyTorch Backend.
    pip install -U kecam
    # Or
    pip install -U keras-cv-attention-models
    # Or
    pip install -U git+https://github.com/leondgarse/keras_cv_attention_models
    
    Refer to each sub directory for detailed usage.
  • Basic model prediction
    from keras_cv_attention_models import volo
    mm = volo.VOLO_d1(pretrained="imagenet")
    
    """ Run predict """
    import tensorflow as tf
    from tensorflow import keras
    from keras_cv_attention_models.test_images import cat
    img = cat()
    imm = keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
    pred = tf.nn.softmax(pred).numpy()  # If classifier activation is not softmax
    print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
    # [('n02124075', 'Egyptian_cat', 0.99664897),
    #  ('n02123045', 'tabby', 0.0007249644),
    #  ('n02123159', 'tiger_cat', 0.00020345),
    #  ('n02127052', 'lynx', 5.4973923e-05),
    #  ('n02123597', 'Siamese_cat', 2.675306e-05)]
    
    Or just use the model preset preprocess_input and decode_predictions:
    from keras_cv_attention_models import coatnet
    mm = coatnet.CoAtNet0()
    
    from keras_cv_attention_models.test_images import cat
    preds = mm(mm.preprocess_input(cat()))
    print(mm.decode_predictions(preds))
    # [[('n02124075', 'Egyptian_cat', 0.9999875), ('n02123045', 'tabby', 5.194884e-06), ...]]
    
    The preset preprocess_input and decode_predictions are also compatible with the PyTorch backend.
    os.environ['KECAM_BACKEND'] = 'torch'
    
    from keras_cv_attention_models import caformer
    mm = caformer.CAFormerS18()
    # >>>> Using PyTorch backend
    # >>>> Aligned input_shape: [3, 224, 224]
    # >>>> Load pretrained from: ~/.keras/models/caformer_s18_224_imagenet.h5
    
    from keras_cv_attention_models.test_images import cat
    preds = mm(mm.preprocess_input(cat()))
    print(preds.shape)
    # torch.Size([1, 1000])
    print(mm.decode_predictions(preds))
    # [[('n02124075', 'Egyptian_cat', 0.8817097), ('n02123045', 'tabby', 0.009335292), ...]]
    
  • Set num_classes=0 to exclude the model top GlobalAveragePooling2D + Dense layers.
    from keras_cv_attention_models import resnest
    mm = resnest.ResNest50(num_classes=0)
    print(mm.output_shape)
    # (None, 7, 7, 2048)
    
  • Setting num_classes={custom output classes} other than 1000 or 0 will just skip loading the header Dense layer weights, as model.load_weights(weight_file, by_name=True, skip_mismatch=True) is used for loading weights.
    from keras_cv_attention_models import swin_transformer_v2
    
    mm = swin_transformer_v2.SwinTransformerV2Tiny_window8(num_classes=64)
    # >>>> Load pretrained from: ~/.keras/models/swin_transformer_v2_tiny_window8_256_imagenet.h5
    # WARNING:tensorflow:Skipping loading weights for layer #601 (named predictions) due to mismatch in shape for weight predictions/kernel:0. Weight expects shape (768, 64). Received saved weight with shape (768, 1000)
    # WARNING:tensorflow:Skipping loading weights for layer #601 (named predictions) due to mismatch in shape for weight predictions/bias:0. Weight expects shape (64,). Received saved weight with shape (1000,)
    
  • Reload your own model weights by setting pretrained="xxx.h5". This is better than calling model.load_weights directly when reloading a model with a different input_shape, where some weight shapes may not match.
    import os
    from keras_cv_attention_models import coatnet
    pretrained = os.path.expanduser('~/.keras/models/coatnet0_224_imagenet.h5')
    mm = coatnet.CoAtNet1(input_shape=(384, 384, 3), pretrained=pretrained)  # No sense, just showing usage
    
  • The alias name kecam can be used instead of keras_cv_attention_models. It's an __init__.py-only package containing from keras_cv_attention_models import *.
    import kecam
    mm = kecam.yolor.YOLOR_CSP()
    imm = kecam.test_images.dog_cat()
    preds = mm(mm.preprocess_input(imm))
    bboxes, labels, confidences = mm.decode_predictions(preds)[0]
    kecam.coco.show_image_with_bboxes(imm, bboxes, labels, confidences)
    
  • The FLOPs calculating method is from TF 2.0 Feature: Flops calculation #32809. For the PyTorch backend, thop is needed: pip install thop.
    from keras_cv_attention_models import coatnet, resnest, model_surgery
    
    model_surgery.get_flops(coatnet.CoAtNet0())
    # >>>> FLOPs: 4,221,908,559, GFLOPs: 4.2219G
    model_surgery.get_flops(resnest.ResNest50())
    # >>>> FLOPs: 5,378,399,992, GFLOPs: 5.3784G
    
  • [Deprecated] tensorflow_addons is not imported by default. When reloading a model depending on GroupNormalization, like MobileViTV2, from h5 directly, import tensorflow_addons manually first.
    import tensorflow_addons as tfa
    
    model_path = os.path.expanduser('~/.keras/models/mobilevit_v2_050_256_imagenet.h5')
    mm = keras.models.load_model(model_path)
    
  • Export a TF model to ONNX. Needs tf2onnx for TF: pip install onnx tf2onnx onnxsim onnxruntime. When using the PyTorch backend, ONNX exporting is supported by PyTorch itself.
    from keras_cv_attention_models import volo, nat, model_surgery
    mm = nat.DiNAT_Small(pretrained=True)
    model_surgery.export_onnx(mm, fuse_conv_bn=True, batch_size=1, simplify=True)
    # Exported simplified onnx: dinat_small.onnx
    
    # Run test
    from keras_cv_attention_models.imagenet import eval_func
    aa = eval_func.ONNXModelInterf(mm.name + '.onnx')
    inputs = np.random.uniform(size=[1, *mm.input_shape[1:]]).astype('float32')
    print(f"{np.allclose(aa(inputs), mm(inputs), atol=1e-5) = }")
    # np.allclose(aa(inputs), mm(inputs), atol=1e-5) = True
    
  • Model summary: model_summary.csv contains gathered model info; see the pandas sketch after this bullet.
    • params: model parameter count in M
    • flops: FLOPs in G
    • input: model input shape
    • acc_metrics: ImageNet Top1 Accuracy for recognition models, COCO val AP for detection models
    • inference_qps: T4 inference queries per second with batch_size=1 + trtexec
    • extra: any extra training info
    from keras_cv_attention_models import plot_func
    plot_series = [
        "efficientnetv2", 'tinynet', 'lcnet', 'mobilenetv3', 'fasternet', 'fastervit', 'ghostnet',
        'inceptionnext', 'efficientvit_b', 'mobilevit', 'convnextv2', 'efficientvit_m', 'hiera',
    ]
    plot_func.plot_model_summary(
        plot_series, model_table="model_summary.csv", log_scale_x=True, allow_extras=['mae_in1k_ft1k']
    )
    
    model_summary
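    A minimal pandas sketch for inspecting model_summary.csv directly (run from the repo root; assuming the fields listed above are the literal CSV column names):
    import pandas as pd

    df = pd.read_csv("model_summary.csv")
    print(df.columns.tolist())  # expect fields like params / flops / input / acc_metrics / inference_qps / extra
    print(df.head())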
  • Code is formatted using line-length=160:
    find ./* -name "*.py" | grep -v __init__ | xargs -I {} black -l 160 {}
    

T4 Inference

  • T4 Inference results in the model tables are tested using trtexec on a Tesla T4 with CUDA=12.0.1-1, Driver=525.60.13. All models are exported as ONNX using the PyTorch backend, with batch_size=1 only. Note: this data is for reference only and may vary with different batch sizes, benchmark tools, platforms, or implementations.
  • All results are tested using the colab trtexec.ipynb, thus reproducible by anyone.
os.environ["KECAM_BACKEND"] = "torch"

from keras_cv_attention_models import convnext, test_images, imagenet
# >>>> Using PyTorch backend
mm = convnext.ConvNeXtTiny()
mm.export_onnx(simplify=True)
# Exported onnx: convnext_tiny.onnx
# Running onnxsim.simplify...
# Exported simplified onnx: convnext_tiny.onnx

# Onnx run test
tt = imagenet.eval_func.ONNXModelInterf('convnext_tiny.onnx')
print(mm.decode_predictions(tt(mm.preprocess_input(test_images.cat()))))
# [[('n02124075', 'Egyptian_cat', 0.880507), ('n02123045', 'tabby', 0.0047998047), ...]]

""" Run trtexec benchmark """
!trtexec --onnx=convnext_tiny.onnx --fp16 --allowGPUFallback --useSpinWait --useCudaGraph

Layers

  • attention_layers is an __init__.py-only module that imports core layers defined in the model architectures, like RelativePositionalEmbedding from botnet, outlook_attention from volo, and many other positional embedding layers / attention blocks.
from keras_cv_attention_models import attention_layers
aa = attention_layers.RelativePositionalEmbedding()
print(f"{aa(tf.ones([1, 4, 14, 16, 256])).shape = }")
# aa(tf.ones([1, 4, 14, 16, 256])).shape = TensorShape([1, 4, 14, 16, 14, 16])

Model surgery

  • model_surgery includes functions used to change model parameters after the model is built.
from keras_cv_attention_models import model_surgery
mm = keras.applications.ResNet50()  # Trainable params: 25,583,592

# Replace all ReLU with PReLU. Trainable params: 25,606,312
mm = model_surgery.replace_ReLU(mm, target_activation='PReLU')

# Fuse conv and batch_norm layers. Trainable params: 25,553,192
mm = model_surgery.convert_to_fused_conv_bn_model(mm)

ImageNet training and evaluating

  • ImageNet contains more detailed usage and some comparison results.
  • Init Imagenet dataset using tensorflow_datasets #9.
  • For a custom dataset, custom_dataset_script.py can be used to create a json format file, which can then be used as --data_name xxx.json for training; detailed usage can be found in Custom recognition dataset.
  • Another way of creating a custom dataset is using tfds.load; refer to Writing custom datasets and Creating private tensorflow_datasets from tfds #48 by @Medicmind.
  • Running an AWS Sagemaker estimator job using keras_cv_attention_models can be found in AWS Sagemaker script example by @Medicmind.
  • The aotnet.AotNet50 default parameter set is a typical ResNet50 architecture with Conv2D use_bias=False and PyTorch-like padding.
  • Default parameters for train_script.py follow the A3 configuration from ResNet strikes back: An improved training procedure in timm, with batch_size=256, input_shape=(160, 160).
    # `antialias` is enabled by default for resize; it can be turned off by setting `--disable_antialias`.
    CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 train_script.py --seed 0 -s aotnet50
    
    # Evaluation using input_shape (224, 224).
    # `antialias` usage should be the same as in training.
    CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m aotnet50_epoch_103_val_acc_0.7674.h5 -i 224 --central_crop 0.95
    # >>>> Accuracy top1: 0.78466 top5: 0.94088
    
    aotnet50_imagenet
  • Restore from a break point by setting --restore_path and --initial_epoch, and keep other parameters the same. restore_path has higher priority than model and additional_model_kwargs, and also restores the optimizer and loss. initial_epoch is mainly for the learning rate scheduler. If not sure where training stopped, check checkpoints/{save_name}_hist.json.
    import json
    with open("checkpoints/aotnet50_hist.json", "r") as ff:
        aa = json.load(ff)
    len(aa['lr'])
    # 41 ==> 41 epochs are finished, initial_epoch is 41 then, restart from epoch 42
    
    CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 train_script.py --seed 0 -r checkpoints/aotnet50_latest.h5 -I 41
    # >>>> Restore model from: checkpoints/aotnet50_latest.h5
    # Epoch 42/105
    
  • eval_script.py is used for evaluating model accuracy. EfficientNetV2 self tested imagenet accuracy #19 shows how different parameters affect model accuracy.
    # evaluating pretrained builtin model
    CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m regnet.RegNetZD8
    # evaluating pretrained timm model
    CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m timm.models.resmlp_12_224 --input_shape 224
    
    # evaluating specific h5 model
    CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m checkpoints/xxx.h5
    # evaluating specific tflite model
    CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m xxx.tflite
    
  • Progressive training refers to PDF 2104.00298 EfficientNetV2: Smaller Models and Faster Training. AotNet50 A3 with progressive input shapes 96 128 160:
    CUDA_VISIBLE_DEVICES='1' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 progressive_train_script.py \
    --progressive_epochs 33 66 -1 \
    --progressive_input_shapes 96 128 160 \
    --progressive_magnitudes 2 4 6 \
    -s aotnet50_progressive_3_lr_steps_100 --seed 0
    
    aotnet50_progressive_160
  • Transfer learning with freeze_backbone or freeze_norm_layers: EfficientNetV2B0 transfer learning on cifar10 testing freezing backbone #55; see the sketch after this list.
  • Token label train test on CIFAR10 #57. Currently not working as well as expected. Token label is an implementation of Github zihangJiang/TokenLabeling, paper PDF 2104.10858 All Tokens Matter: Token Labeling for Training Better Vision Transformers.
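
As a rough illustration of the freeze_backbone idea outside the training scripts (a minimal generic Keras sketch, not train_script.py's own implementation; the layer slicing is an assumption):

from keras_cv_attention_models import efficientnet

# Pretrained backbone with a new 10-class head; the head Dense weights are skipped while loading.
mm = efficientnet.EfficientNetV2B0(num_classes=10, pretrained="imagenet")
for layer in mm.layers[:-1]:  # freeze everything except the final Dense head
    layer.trainable = False
mm.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["acc"])
# mm.fit(train_dataset, epochs=...)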

COCO training and evaluating

  • Currently still under testing.

  • COCO contains more detailed usage.

  • custom_dataset_script.py can be used to create a json format file, which can then be used as --data_name xxx.json for training; detailed usage can be found in Custom detection dataset.

  • Default parameters for coco_train_script.py are EfficientDetD0 with input_shape=(256, 256, 3), batch_size=64, mosaic_mix_prob=0.5, freeze_backbone_epochs=32, total_epochs=105. Technically, any combination of a pyramid-structure backbone + EfficientDet / YOLOX / YOLOR header + anchor_free / yolor / efficientdet anchors is supported.

  • Currently 4 types of anchors are supported; the parameter anchors_mode controls which anchor to use, with a value in ["efficientdet", "anchor_free", "yolor", "yolov8"]. The default None uses the det_header preset.

  • NOTE: YOLOV8 has a default regression_len=64 for the bbox output length. It's typically 4 for other detection models; for YOLOV8 it's reg_max=16 -> regression_len = 16 * 4 == 64.

    | anchors_mode | use_object_scores | num_anchors | anchor_scale | aspect_ratios | num_scales | grid_zero_start |
    | ------------ | ----------------- | ----------- | ------------ | ------------- | ---------- | --------------- |
    | efficientdet | False             | 9           | 4            | [1, 2, 0.5]   | 3          | False           |
    | anchor_free  | True              | 1           | 1            | [1]           | 1          | True            |
    | yolor        | True              | 3           | None         | presets       | None       | offset=0.5      |
    | yolov8       | False             | 1           | 1            | [1]           | 1          | False           |
    # Default EfficientDetD0
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py
    # Default EfficientDetD0 using input_shape 512, optimizer adamw, freezing backbone 16 epochs, total 50 + 5 epochs
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py -i 512 -p adamw --freeze_backbone_epochs 16 --lr_decay_steps 50
    
    # EfficientNetV2B0 backbone + EfficientDetD0 detection header
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone efficientnet.EfficientNetV2B0 --det_header efficientdet.EfficientDetD0
    # ResNest50 backbone + EfficientDetD0 header using yolox like anchor_free anchors
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone resnest.ResNest50 --anchors_mode anchor_free
    # UniformerSmall32 backbone + EfficientDetD0 header using yolor anchors
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone uniformer.UniformerSmall32 --anchors_mode yolor
    
    # Typical YOLOXS with anchor_free anchors
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolox.YOLOXS --freeze_backbone_epochs 0
    # YOLOXS with efficientdet anchors
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolox.YOLOXS --anchors_mode efficientdet --freeze_backbone_epochs 0
    # CoAtNet0 backbone + YOLOX header with yolor anchors
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone coatnet.CoAtNet0 --det_header yolox.YOLOX --anchors_mode yolor
    
    # Typical YOLOR_P6 with yolor anchors
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolor.YOLOR_P6 --freeze_backbone_epochs 0
    # YOLOR_P6 with anchor_free anchors
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolor.YOLOR_P6 --anchors_mode anchor_free  --freeze_backbone_epochs 0
    # ConvNeXtTiny backbone + YOLOR header with efficientdet anchors
    CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone convnext.ConvNeXtTiny --det_header yolor.YOLOR --anchors_mode efficientdet
    

    Note: COCO training is still under testing and may change parameters and default behaviors. Take the risk if you would like to help developing.

  • coco_eval_script.py is used for evaluating model AP / AR on the COCO validation set. It has a dependency pycocotools (pip install pycocotools) which is not in the package requirements. More usage can be found in COCO Evaluation.

    # EfficientDetD0 using resize method bilinear w/o antialias
    CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m efficientdet.EfficientDetD0 --resize_method bilinear --disable_antialias
    # >>>> [COCOEvalCallback] input_shape: (512, 512), pyramid_levels: [3, 7], anchors_mode: efficientdet
    
    # YOLOX using BGR input format
    CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m yolox.YOLOXTiny --use_bgr_input --nms_method hard --nms_iou_or_sigma 0.65
    # >>>> [COCOEvalCallback] input_shape: (416, 416), pyramid_levels: [3, 5], anchors_mode: anchor_free
    
    # YOLOR / YOLOV7 using letterbox_pad and other tricks.
    CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m yolor.YOLOR_CSP --nms_method hard --nms_iou_or_sigma 0.65 \
    --nms_max_output_size 300 --nms_topk -1 --letterbox_pad 64 --input_shape 704
    # >>>> [COCOEvalCallback] input_shape: (704, 704), pyramid_levels: [3, 5], anchors_mode: yolor
    
    # Specify h5 model
    CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m checkpoints/yoloxtiny_yolor_anchor.h5
    # >>>> [COCOEvalCallback] input_shape: (416, 416), pyramid_levels: [3, 5], anchors_mode: yolor
    
  • [Experimental] Training using PyTorch backend

    import os, sys, torch
    os.environ["KECAM_BACKEND"] = "torch"
    
    from keras_cv_attention_models.yolov8 import train, yolov8
    from keras_cv_attention_models import efficientnet
    
    global_device = torch.device("cuda:0") if torch.cuda.is_available() and int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")) >= 0 else torch.device("cpu")
    # model Trainable params: 7,023,904, GFLOPs: 8.1815G
    bb = efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0)
    model = yolov8.YOLOV8_N(backbone=bb, classifier_activation=None, pretrained=None).to(global_device)  # Note: classifier_activation=None
    # model = yolov8.YOLOV8_N(input_shape=(3, None, None), classifier_activation=None, pretrained=None).to(global_device)
    ema = train.train(model, dataset_path="coco.json", initial_epoch=0)
    

    yolov8_training

CLIP training and evaluating

  • CLIP contains more detailed usage.
  • custom_dataset_script.py can be used to create a tsv / json format file, which can then be used as --data_name xxx.tsv for training; detailed usage can be found in Custom caption dataset.
  • Train using clip_train_script.py on COCO captions. The default --data_path is a testing one, datasets/coco_dog_cat/captions.tsv.
    CUDA_VISIBLE_DEVICES=1 TF_XLA_FLAGS="--tf_xla_auto_jit=2" python clip_train_script.py -i 160 -b 128 \
    --text_model_pretrained None --data_path coco_captions.tsv
    
    Train Using PyTorch backend by setting KECAM_BACKEND='torch'
    KECAM_BACKEND='torch' CUDA_VISIBLE_DEVICES=1 python clip_train_script.py -i 160 -b 128 \
    --text_model_pretrained None --data_path coco_captions.tsv
    
    clip_torch_tf

Text training

  • Currently it's only a simple model modified from Github karpathy/nanoGPT.
  • Train using text_train_script.py. As the dataset is randomly sampled, steps_per_epoch needs to be specified:
    CUDA_VISIBLE_DEVICES=1 TF_XLA_FLAGS="--tf_xla_auto_jit=2" python text_train_script.py -m LLaMA2_15M \
    --steps_per_epoch 8000 --batch_size 8 --tokenizer SentencePieceTokenizer
    
    Train Using PyTorch backend by setting KECAM_BACKEND='torch'
    KECAM_BACKEND='torch' CUDA_VISIBLE_DEVICES=1 python text_train_script.py -m LLaMA2_15M \
    --steps_per_epoch 8000 --batch_size 8 --tokenizer SentencePieceTokenizer
    
    Plotting
    from keras_cv_attention_models import plot_func
    hists = ['checkpoints/text_llama2_15m_tensorflow_hist.json', 'checkpoints/text_llama2_15m_torch_hist.json']
    plot_func.plot_hists(hists, addition_plots=['val_loss', 'lr'], skip_first=3)
    
    text_tf_torch

DDPM training

  • Stable Diffusion contains more detailed usage.
  • Note: works better with the PyTorch backend; the TensorFlow one seems to overfit if training longer, like --epochs 200, and evaluation runs ~5 times slower. [???]
  • The dataset can be a directory containing images for basic DDPM training using images only, or a recognition json file created following Custom recognition dataset, which will train using labels as instruction.
    python custom_dataset_script.py --train_images cifar10/train/ --test_images cifar10/test/
    # >>>> total_train_samples: 50000, total_test_samples: 10000, num_classes: 10
    # >>>> Saved to: cifar10.json
    
  • Train using ddpm_train_script.py on cifar10 with labels. The default --data_path is the builtin cifar10.
    # Set --eval_interval 50 as TF evaluation is rather slow [???]
    TF_XLA_FLAGS="--tf_xla_auto_jit=2" CUDA_VISIBLE_DEVICES=1 python ddpm_train_script.py --eval_interval 50
    
    Train Using PyTorch backend by setting KECAM_BACKEND='torch'
    KECAM_BACKEND='torch' CUDA_VISIBLE_DEVICES=1 python ddpm_train_script.py
    
    ddpm_unet_test_E100

Visualizing

  • Visualizing is for visualizing convnet filters or attention map scores.
  • make_and_apply_gradcam_heatmap is for Grad-CAM class activation visualization.
    from keras_cv_attention_models import visualizing, test_images, resnest
    mm = resnest.ResNest50()
    img = test_images.dog()
    superimposed_img, heatmap, preds = visualizing.make_and_apply_gradcam_heatmap(mm, img, layer_name="auto")
    
  • plot_attention_score_maps is for visualizing model attention score maps.
    from keras_cv_attention_models import visualizing, test_images, botnet
    img = test_images.dog()
    _ = visualizing.plot_attention_score_maps(botnet.BotNetSE33T(), img)
    

TFLite Conversion

  • Currently TFLite does not support tf.image.extract_patches / tf.transpose with len(perm) > 4. Some operations may be supported in the latest or tf-nightly version; for example, the previously unsupported gelu / Conv2D with groups>1 are working now. Try those if encountering issues.
  • More discussion can be found in Converting a trained keras CV attention model to TFLite #17. Some speed testing results can be found in How to speed up inference on a quantized model #44.
  • Functions like model_surgery.convert_groups_conv2d_2_split_conv2d and model_surgery.convert_gelu_to_approximate are not needed with an up-to-date TF version.
  • Converting VOLO / HaloNet models is not supported, as they need a longer tf.transpose perm.
  • model_surgery.convert_dense_to_conv converts all Dense layers with 3D / 4D inputs to Conv1D / Conv2D, as TFLite xnnpack currently doesn't support them.
    from keras_cv_attention_models import beit, model_surgery, efficientformer, mobilevit
    
    mm = efficientformer.EfficientFormerL1()
    mm = model_surgery.convert_dense_to_conv(mm)  # Convert all Dense layers
    converter = tf.lite.TFLiteConverter.from_keras_model(mm)
    open(mm.name + ".tflite", "wb").write(converter.convert())
    
    | Model             | Dense, use_xnnpack=false  | Conv, use_xnnpack=false   | Conv, use_xnnpack=true    |
    | ----------------- | ------------------------- | ------------------------- | ------------------------- |
    | MobileViT_S       | Inference (avg) 215371 us | Inference (avg) 163836 us | Inference (avg) 163817 us |
    | EfficientFormerL1 | Inference (avg) 126829 us | Inference (avg) 107053 us | Inference (avg) 107132 us |
  • model_surgery.convert_extract_patches_to_conv converts tf.image.extract_patches to a Conv2D version:
    from keras_cv_attention_models import cotnet, model_surgery
    from keras_cv_attention_models.imagenet import eval_func
    
    mm = cotnet.CotNetSE50D()
    mm = model_surgery.convert_groups_conv2d_2_split_conv2d(mm)
    # mm = model_surgery.convert_gelu_to_approximate(mm)  # Not required if using up-to-date TFLite
    mm = model_surgery.convert_extract_patches_to_conv(mm)
    converter = tf.lite.TFLiteConverter.from_keras_model(mm)
    open(mm.name + ".tflite", "wb").write(converter.convert())
    test_inputs = np.random.uniform(size=[1, *mm.input_shape[1:]])
    print(np.allclose(mm(test_inputs), eval_func.TFLiteModelInterf(mm.name + '.tflite')(test_inputs), atol=1e-7))
    # True
    
  • model_surgery.prepare_for_tflite is just a combination of the above functions:
    from keras_cv_attention_models import beit, model_surgery
    
    mm = beit.BeitBasePatch16()
    mm = model_surgery.prepare_for_tflite(mm)
    converter = tf.lite.TFLiteConverter.from_keras_model(mm)
    open(mm.name + ".tflite", "wb").write(converter.convert())
    
  • Detection models including efficientdet / yolox / yolor can be converted to TFLite format directly. If DecodePredictions needs to be included in the TFLite model, set use_static_output=True for DecodePredictions, as TFLite requires a more static output shape. The model output shape will then be fixed as [batch, max_output_size, 6]. The last dimension 6 means [bbox_top, bbox_left, bbox_bottom, bbox_right, label_index, confidence], and the valid ones are those where confidence > 0.
    """ Init model """
    from keras_cv_attention_models import efficientdet
    model = efficientdet.EfficientDetD0(pretrained="coco")
    
    """ Create a model with DecodePredictions using `use_static_output=True` """
    model.decode_predictions.use_static_output = True
    # Parameters like score_threshold / iou_or_sigma can be set to other values if needed.
    nn = model.decode_predictions(model.outputs[0], score_threshold=0.5)
    bb = keras.models.Model(model.inputs[0], nn)
    
    """ Convert TFLite """
    converter = tf.lite.TFLiteConverter.from_keras_model(bb)
    open(bb.name + ".tflite", "wb").write(converter.convert())
    
    """ Inference test """
    from keras_cv_attention_models.imagenet import eval_func
    from keras_cv_attention_models import test_images
    
    dd = eval_func.TFLiteModelInterf(bb.name + ".tflite")
    imm = test_images.cat()
    inputs = tf.expand_dims(tf.image.resize(imm, dd.input_shape[1:-1]), 0)
    inputs = keras.applications.imagenet_utils.preprocess_input(inputs, mode='torch')
    preds = dd(inputs)[0]
    print(f"{preds.shape = }")
    # preds.shape = (100, 6)
    
    pred = preds[preds[:, -1] > 0]
    bboxes, labels, confidences = pred[:, :4], pred[:, 4], pred[:, -1]
    print(f"{bboxes = }, {labels = }, {confidences = }")
    # bboxes = array([[0.22825494, 0.47238672, 0.816262  , 0.8700745 ]], dtype=float32),
    # labels = array([16.], dtype=float32),
    # confidences = array([0.8309707], dtype=float32)
    
    """ Show result """
    from keras_cv_attention_models.coco import data
    data.show_image_with_bboxes(imm, bboxes, labels, confidences, num_classes=90)
    

Using PyTorch as backend

  • Experimental Keras PyTorch Backend.
  • Set the OS environment variable export KECAM_BACKEND='torch' to enable the PyTorch backend.
  • Currently supports most recognition and detection models except hornet*gf / nfnets / volo. Detection models use torchvision.ops.nms while running prediction.
  • Basic model build and prediction.
    • Will load the same h5 weights as the TF one if available.
    • Note: input_shape will auto fit the image data format. Given input_shape=(224, 224, 3) or input_shape=(3, 224, 224), both will be set to (3, 224, 224) if channels_first.
    • Note: the model is set to eval mode by default.
    os.environ['KECAM_BACKEND'] = 'torch'
    from keras_cv_attention_models import res_mlp
    mm = res_mlp.ResMLP12()
    # >>>> Load pretrained from: ~/.keras/models/resmlp12_imagenet.h5
    print(f"{mm.input_shape = }")
    # mm.input_shape = [None, 3, 224, 224]
    
    import torch
    print(f"{isinstance(mm, torch.nn.Module) = }")
    # isinstance(mm, torch.nn.Module) = True
    
    # Run prediction
    from keras_cv_attention_models.test_images import cat
    print(mm.decode_predictions(mm(mm.preprocess_input(cat())))[0])
    # [('n02124075', 'Egyptian_cat', 0.9597896), ('n02123045', 'tabby', 0.012809471), ...]
    
  • Export typical PyTorch onnx / pth.
    import torch
    torch.onnx.export(mm, torch.randn(1, 3, *mm.input_shape[2:]), mm.name + ".onnx")
    
    # Or by export_onnx
    mm.export_onnx()
    # Exported onnx: resmlp12.onnx
    
    mm.export_pth()
    # Exported pth: resmlp12.pth
    
  • Save weights as h5. This h5 can also be loaded into the typical TF backend model; see the sketch below. Currently only weights are supported, without the model structure.
    mm.save_weights("foo.h5")
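
    The file saved above can then be reloaded under the default TF backend in a fresh process (a sketch; KECAM_BACKEND left unset, same ResMLP12 architecture as above):
    from keras_cv_attention_models import res_mlp
    mm = res_mlp.ResMLP12(pretrained="foo.h5")  # reloads the weights saved from the torch backend model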
    
  • Training with compile and fit. Note: loss function arguments should be y_true, y_pred, while typical torch loss functions use y_pred, y_true.
    import torch
    from keras_cv_attention_models.backend import models, layers
    mm = models.Sequential([layers.Input([3, 32, 32]), layers.Conv2D(32, 3), layers.GlobalAveragePooling2D(), layers.Dense(10)])
    if torch.cuda.is_available():
        _ = mm.to("cuda")
    xx = torch.rand([64, *mm.input_shape[1:]])
    yy = torch.functional.F.one_hot(torch.randint(0, mm.output_shape[-1], size=[64]), mm.output_shape[-1]).float()
    loss = lambda y_true, y_pred: (y_true - y_pred.float()).abs().mean()
    # Here using `train_compile` instead of `compile`, as `compile` is already taken by `nn.Module`.
    mm.train_compile(optimizer="AdamW", loss=loss, metrics='acc', grad_accumulate=4)
    mm.fit(xx, yy, epochs=2, batch_size=4)
    

Using keras core as backend

  • [Experimental] Set the OS environment variable export KECAM_BACKEND='keras_core' to enable the keras_core backend. Not using keras>3.0, as it's still not working with TensorFlow==2.15.0.
  • keras-core has its own backends, supporting tensorflow / torch / jax, selected by editing the "backend" value in ~/.keras/keras.json.
  • Currently most recognition models except HaloNet / BotNet are supported; GPT2 / LLaMA2 are also supported.
  • Basic model build and prediction.
    !pip install sentencepiece  # required for llama2 tokenizer
    os.environ['KECAM_BACKEND'] = 'keras_core'
    os.environ['KERAS_BACKEND'] = 'jax'
    import kecam
    print(f"{kecam.backend.backend() = }")
    # kecam.backend.backend() = 'jax'
    mm = kecam.llama2.LLaMA2_42M()
    # >>>> Load pretrained from: ~/.keras/models/llama2_42m_tiny_stories.h5
    mm.run_prediction('As evening fell, a maiden stood at the edge of a wood. In her hands,')
    # >>>> Load tokenizer from file: ~/.keras/datasets/llama_tokenizer.model
    # <s>
    # As evening fell, a maiden stood at the edge of a wood. In her hands, she held a beautiful diamond. Everyone was surprised to see it.
    # "What is it?" one of the kids asked.
    # "It's a diamond," the maiden said.
    # ...
    

Recognition Models

AotNet

  • Keras AotNet is just a ResNet / ResNetV2-like framework that takes parameters like attn_types, se_ratio and others, which are used to apply different types of attention layers. It works like byoanet / byobnet from timm.
  • The default parameter set is a typical ResNet architecture with Conv2D use_bias=False and PyTorch-like padding.
from keras_cv_attention_models import aotnet
# Mixing se and outlook and halo and mhsa and cot_attention, 21M parameters.
# 50 is just a picked number that is larger than the relative `num_block`.
attn_types = [None, "outlook", ["bot", "halo"] * 50, "cot"]
se_ratio = [0.25, 0, 0, 0]
model = aotnet.AotNet50V2(attn_types=attn_types, se_ratio=se_ratio, stem_type="deep", strides=1)
model.summary()

BEiT

ModelParamsFLOPsInputTop1 AccT4 Inference
BeitBasePatch16, 21k_ft1k86.53M17.61G22485.240321.226 qps
- 21k_ft1k, 38486.74M55.70G38486.808164.705 qps
BeitLargePatch16, 21k_ft1k304.43M61.68G22487.476105.998 qps
- 21k_ft1k, 384305.00M191.65G38488.38245.7307 qps
- 21k_ft1k, 512305.67M363.46G51288.58421.3097 qps

BEiTV2

ModelParamsFLOPsInputTop1 AccT4 Inference
BeitV2BasePatch1686.53M17.61G22485.5322.52 qps
- 21k_ft1k86.53M17.61G22486.5322.52 qps
BeitV2LargePatch16304.43M61.68G22487.3105.734 qps
- 21k_ft1k304.43M61.68G22488.4105.734 qps

BotNet

ModelParamsFLOPsInputTop1 AccT4 Inference
BotNet5021M5.42G224746.454 qps
BotNet10141M9.13G224448.102 qps
BotNet15256M12.84G224316.671 qps
BotNet26T12.5M3.30G25679.2461188.84 qps
BotNextECA26T10.59M2.45G25679.2701038.19 qps
BotNetSE33T13.7M3.89G25681.2610.429 qps

CAFormer

ModelParamsFLOPsInputTop1 AccT4 Inference
CAFormerS1826M4.1G22483.6399.127 qps
- 38426M13.4G38485.0181.993 qps
- 21k_ft1k26M4.1G22484.1399.127 qps
- 21k_ft1k, 38426M13.4G38485.4181.993 qps
CAFormerS3639M8.0G22484.5204.328 qps
- 38439M26.0G38485.7102.04 qps
- 21k_ft1k39M8.0G22485.8204.328 qps
- 21k_ft1k, 38439M26.0G38486.9102.04 qps
CAFormerM3656M13.2G22485.2162.257 qps
- 38456M42.0G38486.265.6188 qps
- 21k_ft1k56M13.2G22486.6162.257 qps
- 21k_ft1k, 38456M42.0G38487.565.6188 qps
CAFormerB3699M23.2G22485.5116.865 qps
- 38499M72.2G38486.450.0244 qps
- 21k_ft1k99M23.2G22487.4116.865 qps
- 21k_ft1k, 38499M72.2G38488.150.0244 qps
ModelParamsFLOPsInputTop1 AccT4 Inference
ConvFormerS1827M3.9G22483.0295.114 qps
- 38427M11.6G38484.4145.923 qps
- 21k_ft1k27M3.9G22483.7295.114 qps
- 21k_ft1k, 38427M11.6G38485.0145.923 qps
ConvFormerS3640M7.6G22484.1161.609 qps
- 38440M22.4G38485.480.2101 qps
- 21k_ft1k40M7.6G22485.4161.609 qps
- 21k_ft1k, 38440M22.4G38486.480.2101 qps
ConvFormerM3657M12.8G22484.5130.161 qps
- 38457M37.7G38485.663.9712 qps
- 21k_ft1k57M12.8G22486.1130.161 qps
- 21k_ft1k, 38457M37.7G38486.963.9712 qps
ConvFormerB36100M22.6G22484.898.0751 qps
- 384100M66.5G38485.748.5897 qps
- 21k_ft1k100M22.6G22487.098.0751 qps
- 21k_ft1k, 384100M66.5G38487.648.5897 qps

CMT

ModelParamsFLOPsInputTop1 AccT4 Inference
CMTTiny, (Self trained 105 epochs)9.5M0.65G16077.4315.566 qps
- (305 epochs)9.5M0.65G16078.94315.566 qps
- 224, (fine-tuned 69 epochs)9.5M1.32G22480.73254.87 qps
CMTTiny_torch, (1000 epochs)9.5M0.65G16079.2338.207 qps
CMTXS_torch15.2M1.58G19281.8241.288 qps
CMTSmall_torch25.1M4.09G22483.5171.109 qps
CMTBase_torch45.7M9.42G25684.5103.34 qps

CoaT

ModelParamsFLOPsInputTop1 AccT4 Inference
CoaTLiteTiny5.7M1.60G22477.5450.27 qps
CoaTLiteMini11M2.00G22479.1452.884 qps
CoaTLiteSmall20M3.97G22481.9248.846 qps
CoaTTiny5.5M4.33G22478.3152.495 qps
CoaTMini10M6.78G22481.0124.845 qps

CoAtNet

ModelParamsFLOPsInputTop1 AccT4 Inference
CoAtNet0, 160, (105 epochs)23.3M2.09G16080.48584.059 qps
CoAtNet0, (305 epochs)23.8M4.22G22482.79400.333 qps
CoAtNet025M4.6G22482.0400.333 qps
- use_dw_strides=False25M4.2G22481.6461.197 qps
CoAtNet142M8.8G22483.5206.954 qps
- use_dw_strides=False42M8.4G22483.3228.938 qps
CoAtNet275M16.6G22484.1156.359 qps
- use_dw_strides=False75M15.7G22484.1165.846 qps
CoAtNet2, 21k_ft1k75M16.6G22487.1156.359 qps
CoAtNet3168M34.7G22484.595.0703 qps
CoAtNet3, 21k_ft1k168M34.7G22487.695.0703 qps
CoAtNet3, 21k_ft1k168M203.1G51287.995.0703 qps
CoAtNet4, 21k_ft1k275M360.9G51288.174.6022 qps
CoAtNet4, 21k_ft1k, PT-RA-E150275M360.9G51288.5674.6022 qps

ConvNeXt

ModelParamsFLOPsInputTop1 AccT4 Inference
ConvNeXtTiny28M4.49G22482.1361.58 qps
- 21k_ft1k28M4.49G22482.9361.58 qps
- 21k_ft1k, 38428M13.19G38484.1182.134 qps
ConvNeXtSmall50M8.73G22483.1202.007 qps
- 21k_ft1k50M8.73G22484.6202.007 qps
- 21k_ft1k, 38450M25.67G38485.8108.125 qps
ConvNeXtBase89M15.42G22483.8160.036 qps
- 38489M45.32G38485.183.3095 qps
- 21k_ft1k89M15.42G22485.8160.036 qps
- 21k_ft1k, 38489M45.32G38486.883.3095 qps
ConvNeXtLarge198M34.46G22484.3102.27 qps
- 384198M101.28G38485.547.2086 qps
- 21k_ft1k198M34.46G22486.6102.27 qps
- 21k_ft1k, 384198M101.28G38487.547.2086 qps
ConvNeXtXlarge, 21k_ft1k350M61.06G22487.040.5776 qps
- 21k_ft1k, 384350M179.43G38487.821.797 qps
ConvNeXtXXLarge, clip846M198.09G25688.6

ConvNeXtV2

ModelParamsFLOPsInputTop1 AccT4 Inference
ConvNeXtV2Atto3.7M0.55G22476.7705.822 qps
ConvNeXtV2Femto5.2M0.78G22478.5728.02 qps
ConvNeXtV2Pico9.1M1.37G22480.3591.502 qps
ConvNeXtV2Nano15.6M2.45G22481.9471.918 qps
- 21k_ft1k15.6M2.45G22482.1471.918 qps
- 21k_ft1k, 38415.6M7.21G38483.4213.802 qps
ConvNeXtV2Tiny28.6M4.47G22483.0301.982 qps
- 21k_ft1k28.6M4.47G22483.9301.982 qps
- 21k_ft1k, 38428.6M13.1G38485.1139.578 qps
ConvNeXtV2Base89M15.4G22484.9132.575 qps
- 21k_ft1k89M15.4G22486.8132.575 qps
- 21k_ft1k, 38489M45.2G38487.766.5729 qps
ConvNeXtV2Large198M34.4G22485.886.8846 qps
- 21k_ft1k198M34.4G22487.386.8846 qps
- 21k_ft1k, 384198M101.1G38488.224.4542 qps
ConvNeXtV2Huge660M115G22486.3
- 21k_ft1k660M337.9G38488.7
- 21k_ft1k, 384660M600.8G51288.9

CoTNet

ModelParamsFLOPsInputTop1 AccT4 Inference
CotNet5022.2M3.25G22481.3324.913 qps
CotNetSE50D23.1M4.05G22481.6513.077 qps
CotNet10138.3M6.07G22482.8183.824 qps
CotNetSE101D40.9M8.44G22483.2251.487 qps
CotNetSE152D55.8M12.22G22484.0175.469 qps
CotNetSE152D55.8M24.92G32084.6175.469 qps

CSPNeXt

ModelParamsFLOPsInputTop1 AccT4 Inference
CSPNeXtTiny2.73M0.34G22469.44
CSPNeXtSmall4.89M0.66G22474.41
CSPNeXtMedium13.05M1.92G22479.27
CSPNeXtLarge27.16M4.19G22481.30
CSPNeXtXLarge48.85M7.75G22482.10

DaViT

ModelParamsFLOPsInputTop1 AccT4 Inference
DaViT_T28.36M4.56G22482.8224.563 qps
DaViT_S49.75M8.83G22484.2145.838 qps
DaViT_B87.95M15.55G22484.6114.527 qps
DaViT_L, 21k_ft1k196.8M103.2G38487.534.7015 qps
DaViT_H, 1.5B348.9M327.3G51290.212.363 qps
DaViT_G, 1.5B1.406B1.022T51290.4

DiNAT

ModelParamsFLOPsInputTop1 AccT4 Inference
DiNAT_Mini20.0M2.73G22481.883.9943 qps
DiNAT_Tiny27.9M4.34G22482.761.1902 qps
DiNAT_Small50.7M7.84G22483.841.0343 qps
DiNAT_Base89.8M13.76G22484.430.1332 qps
DiNAT_Large, 21k_ft1k200.9M30.58G22486.618.4936 qps
- 21k, (num_classes=21841)200.9M30.58G224
- 21k_ft1k, 384200.9M89.86G38487.4
DiNAT_Large_K11, 21k_ft1k201.1M92.57G38487.5

DINOv2

ModelParamsFLOPsInputTop1 AccT4 Inference
DINOv2_ViT_Small1422.83M47.23G51881.1165.271 qps
DINOv2_ViT_Base1488.12M152.6G51884.554.9769 qps
DINOv2_ViT_Large14306.4M509.6G51886.317.4108 qps
DINOv2_ViT_Giant141139.6M1790.3G51886.5

EdgeNeXt

ModelParamsFLOPsInputTop1 AccT4 Inference
EdgeNeXt_XX_Small1.33M266M25671.23902.957 qps
EdgeNeXt_X_Small2.34M547M25674.96638.346 qps
EdgeNeXt_Small5.59M1.27G25679.41536.762 qps
- usi5.59M1.27G25681.07536.762 qps
EdgeNeXt_Base18.5M3.86G25682.47383.461 qps
- usi18.5M3.86G25683.31383.461 qps
- 21k_ft1k18.5M3.86G25683.68383.461 qps

EfficientFormer

ModelParamsFLOPsInputTop1 AccT4 Inference
EfficientFormerL1, distill12.3M1.31G22479.21214.22 qps
EfficientFormerL3, distill31.4M3.95G22482.4596.705 qps
EfficientFormerL7, distill74.4M9.79G22483.3298.434 qps

EfficientFormerV2

ModelParamsFLOPsInputTop1 AccT4 Inference
EfficientFormerV2S0, distill3.60M405.2M22476.21114.38 qps
EfficientFormerV2S1, distill6.19M665.6M22479.7841.186 qps
EfficientFormerV2S2, distill12.7M1.27G22482.0573.9 qps
EfficientFormerV2L, distill26.3M2.59G22483.5377.224 qps

EfficientNet

ModelParamsFLOPsInputTop1 AccT4 Inference
EfficientNetV1B05.3M0.39G22477.61129.93 qps
- NoisyStudent5.3M0.39G22478.81129.93 qps
EfficientNetV1B17.8M0.70G24079.6758.639 qps
- NoisyStudent7.8M0.70G24081.5758.639 qps
EfficientNetV1B29.1M1.01G26080.5668.959 qps
- NoisyStudent9.1M1.01G26082.4668.959 qps
EfficientNetV1B312.2M1.86G30081.9473.607 qps
- NoisyStudent12.2M1.86G30084.1473.607 qps
EfficientNetV1B419.3M4.46G38083.3265.244 qps
- NoisyStudent19.3M4.46G38085.3265.244 qps
EfficientNetV1B530.4M10.40G45684.3146.758 qps
- NoisyStudent30.4M10.40G45686.1146.758 qps
EfficientNetV1B643.0M19.29G52884.888.0369 qps
- NoisyStudent43.0M19.29G52886.488.0369 qps
EfficientNetV1B766.3M38.13G60085.252.6616 qps
- NoisyStudent66.3M38.13G60086.952.6616 qps
EfficientNetV1L2, NoisyStudent480.3M477.98G80088.4

EfficientNetEdgeTPU

ModelParamsFLOPsInputTop1 AccT4 Inference
EfficientNetEdgeTPUSmall5.49M1.79G22478.071459.38 qps
EfficientNetEdgeTPUMedium6.90M3.01G24079.251028.95 qps
EfficientNetEdgeTPULarge10.59M7.94G30081.32527.034 qps

EfficientNetV2

ModelParamsFLOPsInputTop1 AccT4 Inference
EfficientNetV2B07.1M0.72G22478.71109.84 qps
- 21k_ft1k7.1M0.72G22477.55?1109.84 qps
EfficientNetV2B18.1M1.21G24079.8842.372 qps
- 21k_ft1k8.1M1.21G24079.03?842.372 qps
EfficientNetV2B210.1M1.71G26080.5762.865 qps
- 21k_ft1k10.1M1.71G26079.48?762.865 qps
EfficientNetV2B314.4M3.03G30082.1548.501 qps
- 21k_ft1k14.4M3.03G30082.46?548.501 qps
EfficientNetV2T13.6M3.18G28882.34496.483 qps
EfficientNetV2T_GC13.7M3.19G28882.46368.763 qps
EfficientNetV2S21.5M8.41G38483.9344.109 qps
- 21k_ft1k21.5M8.41G38484.9344.109 qps
EfficientNetV2M54.1M24.69G48085.2145.346 qps
- 21k_ft1k54.1M24.69G48086.2145.346 qps
EfficientNetV2L119.5M56.27G48085.785.6514 qps
- 21k_ft1k119.5M56.27G48086.985.6514 qps
EfficientNetV2XL, 21k_ft1k206.8M93.66G51287.255.141 qps

EfficientViT_B

ModelParamsFLOPsInputTop1 AccT4 Inference
EfficientViT_B03.41M0.12G22471.6 ?1581.76 qps
EfficientViT_B19.10M0.58G22479.4943.587 qps
- 2569.10M0.78G25679.9840.844 qps
- 2889.10M1.03G28880.4680.088 qps
EfficientViT_B224.33M1.68G22482.1583.295 qps
- 25624.33M2.25G25682.7507.187 qps
- 28824.33M2.92G28883.1419.93 qps
EfficientViT_B348.65M4.14G22483.5329.764 qps
- 25648.65M5.51G25683.8288.605 qps
- 28848.65M7.14G28884.2229.992 qps
EfficientViT_L152.65M5.28G22484.48503.068 qps
EfficientViT_L263.71M6.98G22485.05396.255 qps
- 38463.71M20.7G38485.98207.322 qps
EfficientViT_L3246.0M27.6G22485.814174.926 qps
- 384246.0M81.6G38486.40886.895 qps

EfficientViT_M

ModelParamsFLOPsInputTop1 AccT4 Inference
EfficientViT_M02.35M79.4M22463.2814.522 qps
EfficientViT_M12.98M167M22468.4948.041 qps
EfficientViT_M24.19M201M22470.8906.286 qps
EfficientViT_M36.90M263M22473.4758.086 qps
EfficientViT_M48.80M299M22474.3672.891 qps
EfficientViT_M512.47M522M22477.1577.254 qps

EVA

ModelParamsFLOPsInputTop1 AccT4 Inference
EvaLargePatch14, 21k_ft1k304.14M61.65G19688.59115.532 qps
- 21k_ft1k, 336304.53M191.55G33689.2053.3467 qps
EvaGiantPatch14, clip1012.6M267.40G22489.10
- m30m1013.0M621.45G33689.57
- m30m1014.4M1911.61G56089.80

EVA02

ModelParamsFLOPsInputTop1 AccT4 Inference
EVA02TinyPatch14, mim_in22k_ft1k5.76M4.72G33680.658320.123 qps
EVA02SmallPatch14, mim_in22k_ft1k22.13M15.57G33685.74161.774 qps
EVA02BasePatch14, mim_in22k_ft22k_ft1k87.12M107.6G44888.69234.3962 qps
EVA02LargePatch14, mim_m38m_ft22k_ft1k305.08M363.68G44890.054

FasterNet

ModelParamsFLOPsInputTop1 AccT4 Inference
FasterNetT03.9M0.34G22471.91890.83 qps
FasterNetT17.6M0.85G22476.21788.16 qps
FasterNetT215.0M1.90G22478.91353.12 qps
FasterNetS31.1M4.55G22481.3818.814 qps
FasterNetM53.5M8.72G22483.0436.383 qps
FasterNetL93.4M15.49G22483.5319.809 qps

FasterViT

ModelParamsFLOPsInputTop1 AccT4 Inference
FasterViT031.40M3.51G22482.1716.809 qps
FasterViT153.37M5.52G22483.2491.971 qps
FasterViT275.92M9.00G22484.2377.006 qps
FasterViT3159.55M18.75G22484.9216.481 qps
FasterViT4351.12M41.57G22485.471.6303 qps
FasterViT5957.52M114.08G22485.6
FasterViT6, +.21360.33M144.13G22485.8

FastViT

ModelParamsFLOPsInputTop1 AccT4 Inference
FastViT_T84.03M0.65G25676.21020.29 qps
- distill4.03M0.65G25677.21020.29 qps
- deploy=True3.99M0.64G25676.21323.14 qps
FastViT_T127.55M1.34G25679.3734.867 qps
- distill7.55M1.34G25680.3734.867 qps
- deploy=True7.50M1.33G25679.3956.332 qps
FastViT_S129.47M1.74G25679.9666.669 qps
- distill9.47M1.74G25681.1666.669 qps
- deploy=True9.42M1.74G25679.9881.429 qps
FastViT_SA1211.58M1.88G25680.9656.95 qps
- distill11.58M1.88G25681.9656.95 qps
- deploy=True11.54M1.88G25680.9833.011 qps
FastViT_SA2421.55M3.66G25682.7371.84 qps
- distill21.55M3.66G25683.4371.84 qps
- deploy=True21.49M3.66G25682.7444.055 qps
FastViT_SA3631.53M5.44G25683.6267.986 qps
- distill31.53M5.44G25684.2267.986 qps
- deploy=True31.44M5.43G25683.6325.967 qps
FastViT_MA3644.07M7.64G25683.9211.928 qps
- distill44.07M7.64G25684.6211.928 qps
- deploy=True43.96M7.63G25683.9274.559 qps

FBNetV3

ModelParamsFLOPsInputTop1 AccT4 Inference
FBNetV3B5.57M539.82M25679.15713.882 qps
FBNetV3D10.31M665.02M25679.68635.963 qps
FBNetV3G16.62M1379.30M25682.05478.835 qps

FlexiViT

ModelParamsFLOPsInputTop1 AccT4 Inference
FlexiViTSmall22.06M5.36G24082.53744.578 qps
FlexiViTBase86.59M20.33G24084.66301.948 qps
FlexiViTLarge304.47M71.09G24085.64105.187 qps

GCViT

ModelParamsFLOPsInputTop1 AccT4 Inference
GCViT_XXTiny12.0M2.15G22479.9337.7 qps
GCViT_XTiny20.0M2.96G22482.0255.625 qps
GCViT_Tiny28.2M4.83G22483.5174.553 qps
GCViT_Tiny234.5M6.28G22483.7
GCViT_Small51.1M8.63G22484.3131.577 qps
GCViT_Small268.6M11.7G22484.8
GCViT_Base90.3M14.9G22485.0105.845 qps
GCViT_Large202.1M32.8G22485.7
- 21k_ft1k202.1M32.8G22486.6
- 21k_ft1k, 384202.9M105.1G38487.4
- 21k_ft1k, 512203.8M205.1G51287.6

GhostNet

ModelParamsFLOPsInputTop1 AccT4 Inference
GhostNet_0502.59M42.6M22466.881272.25 qps
GhostNet_1005.18M141.7M22474.161167.4 qps
GhostNet_1307.36M227.7M22475.791024.49 qps
- ssld7.36M227.7M22479.381024.49 qps

GhostNetV2

ModelParamsFLOPsInputTop1 AccT4 Inference
GhostNetV2_1006.12M168.5M22475.3797.088 qps
GhostNetV2_1308.96M271.1M22476.9722.668 qps
GhostNetV2_16012.39M400.9M22477.8572.268 qps

GMLP

ModelParamsFLOPsInputTop1 AccT4 Inference
GMLPTiny166M1.35G22472.3234.187 qps
GMLPS1620M4.44G22479.6138.363 qps
GMLPB1673M15.82G22481.677.816 qps

GPViT

ModelParamsFLOPsInputTop1 AccT4 Inference
GPViT_L19.59M6.15G22480.5210.166 qps
GPViT_L224.2M15.74G22483.4139.656 qps
GPViT_L336.7M23.54G22484.1131.284 qps
GPViT_L475.5M48.29G22484.394.1899 qps

HaloNet

ModelParamsFLOPsInputTop1 AccT4 Inference
HaloNextECA26T10.7M2.43G25679.501028.93 qps
HaloNet26T12.5M3.18G25679.131096.79 qps
HaloNetSE33T13.7M3.55G25680.99582.008 qps
HaloRegNetZB11.68M1.97G22481.042575.961 qps
HaloNet50T22.7M5.29G25681.70512.677 qps
HaloBotNet50T22.6M5.02G25682.0431.616 qps

Hiera

ModelParamsFLOPsInputTop1 AccT4 Inference
HieraTiny, mae_in1k_ft1k27.91M4.93G22482.8644.356 qps
HieraSmall, mae_in1k_ft1k35.01M6.44G22483.8491.669 qps
HieraBase, mae_in1k_ft1k51.52M9.43G22484.5351.542 qps
HieraBasePlus, mae_in1k_ft1k69.90M12.71G22485.2291.446 qps
HieraLarge, mae_in1k_ft1k213.74M40.43G22486.1111.042 qps
HieraHuge, mae_in1k_ft1k672.78M125.03G22486.9

HorNet

ModelParamsFLOPsInputTop1 AccT4 Inference
HorNetTiny22.4M4.01G22482.8222.665 qps
HorNetTinyGF23.0M3.94G22483.0
HorNetSmall49.5M8.87G22483.8166.998 qps
HorNetSmallGF50.4M8.77G22484.0
HorNetBase87.3M15.65G22484.2133.842 qps
HorNetBaseGF88.4M15.51G22484.3
HorNetLarge194.5M34.91G22486.889.8254 qps
HorNetLargeGF196.3M34.72G22487.0
HorNetLargeGF201.8M102.0G38487.7

IFormer

ModelParamsFLOPsInputTop1 AccT4 Inference
IFormerSmall19.9M4.88G22483.4254.392 qps
- 38420.9M16.29G38484.6128.98 qps
IFormerBase47.9M9.44G22484.6147.868 qps
- 38448.9M30.86G38485.777.8391 qps
IFormerLarge86.6M14.12G22484.6113.434 qps
- 38487.7M45.74G38485.860.0292 qps

InceptionNeXt

ModelParamsFLOPsInputTop1 AccT4 Inference
InceptionNeXtTiny28.05M4.21G22482.3606.527 qps
InceptionNeXtSmall49.37M8.39G22483.5329.01 qps
InceptionNeXtBase86.67M14.88G22484.0260.639 qps
- 38486.67M43.73G38485.2142.888 qps

LCNet

ModelParamsFLOPsInputTop1 AccT4 Inference
LCNet0501.88M46.02M22463.103107.89 qps
- ssld1.88M46.02M22466.103107.89 qps
LCNet0752.36M96.82M22468.823083.55 qps
LCNet1002.95M158.28M22472.102752.6 qps
- ssld2.95M158.28M22474.392752.6 qps
LCNet1504.52M338.05M22473.712250.69 qps
LCNet2006.54M585.35M22475.182028.31 qps
LCNet2509.04M900.16M22476.601686.7 qps
- ssld9.04M900.16M22480.821686.7 qps

LeViT

ModelParamsFLOPsInputTop1 AccT4 Inference
LeViT128S, distill7.8M0.31G22476.6800.53 qps
LeViT128, distill9.2M0.41G22478.6628.714 qps
LeViT192, distill11M0.66G22480.0597.299 qps
LeViT256, distill19M1.13G22481.6538.885 qps
LeViT384, distill39M2.36G22482.6460.139 qps

MaxViT

ModelParamsFLOPsInputTop1 AccT4 Inference
MaxViT_Tiny31M5.6G22483.62195.283 qps
- 38431M17.7G38485.2492.5725 qps
- 51231M33.7G51285.7252.6485 qps
MaxViT_Small69M11.7G22484.45149.286 qps
- 38469M36.1G38485.7461.5757 qps
- 51269M67.6G51286.1934.7002 qps
MaxViT_Base119M24.2G22484.9574.7351 qps
- 384119M74.2G38486.3431.9028 qps
- 512119M138.5G51286.6617.8139 qps
- imagenet21k135M24.2G22474.7351 qps
- 21k_ft1k, 384119M74.2G38488.2431.9028 qps
- 21k_ft1k, 512119M138.5G51288.3817.8139 qps
MaxViT_Large212M43.9G22485.1758.0967 qps
- 384212M133.1G38486.4024.1388 qps
- 512212M245.4G51286.7013.063 qps
- imagenet21k233M43.9G22458.0967 qps
- 21k_ft1k, 384212M133.1G38488.3224.1388 qps
- 21k_ft1k, 512212M245.4G51288.4613.063 qps
MaxViT_XLarge, imagenet21k507M97.7G224
- 21k_ft1k, 384475M293.7G38488.51
- 21k_ft1k, 512475M535.2G51288.70

MetaTransFormer

ModelParamsFLOPsInputTop1 AccT4 Inference
MetaTransformerBasePatch16, laion_2b86.86M55.73G38485.4150.731 qps
MetaTransformerLargePatch14, laion_2b304.53M191.6G33688.150.1536 qps

MLP mixer

ModelParamsFLOPsInputTop1 AccT4 Inference
MLPMixerS32, JFT19.1M1.01G22468.70488.839 qps
MLPMixerS16, JFT18.5M3.79G22473.83451.962 qps
MLPMixerB32, JFT60.3M3.25G22475.53247.629 qps
- sam60.3M3.25G22472.47247.629 qps
MLPMixerB1659.9M12.64G22476.44207.423 qps
- 21k_ft1k59.9M12.64G22480.64207.423 qps
- sam59.9M12.64G22477.36207.423 qps
- JFT59.9M12.64G22480.00207.423 qps
MLPMixerL32, JFT206.9M11.30G22480.6795.1865 qps
MLPMixerL16208.2M44.66G22471.7677.9928 qps
- 21k_ft1k208.2M44.66G22482.8977.9928 qps
- JFT208.2M44.66G22484.8277.9928 qps
- 448208.2M178.54G44883.91
- 448, JFT208.2M178.54G44886.78
MLPMixerH14, JFT432.3M121.22G22486.32
- 448, JFT432.3M484.73G44887.94

MobileNetV3

ModelParamsFLOPsInputTop1 AccT4 Inference
MobileNetV3Small0501.29M24.92M22457.892458.28 qps
MobileNetV3Small0752.04M44.35M22465.242286.44 qps
MobileNetV3Small1002.54M57.62M22467.662058.06 qps
MobileNetV3Large0753.99M156.30M22473.441643.78 qps
MobileNetV3Large1005.48M218.73M22475.771629.44 qps
- miil5.48M218.73M22477.921629.44 qps

MobileViT

ModelParamsFLOPsInputTop1 AccT4 Inference
MobileViT_XXS1.3M0.42G25669.01319.43 qps
MobileViT_XS2.3M1.05G25674.71019.57 qps
MobileViT_S5.6M2.03G25678.3790.943 qps

MobileViT_V2

ModelParamsFLOPsInputTop1 AccT4 Inference
MobileViT_V2_0501.37M0.47G25670.18718.337 qps
MobileViT_V2_0752.87M1.04G25675.56642.323 qps
MobileViT_V2_1004.90M1.83G25678.09591.217 qps
MobileViT_V2_1257.48M2.84G25679.65510.25 qps
MobileViT_V2_15010.6M4.07G25680.38466.482 qps
- 21k_ft1k10.6M4.07G25681.46466.482 qps
- 21k_ft1k, 38410.6M9.15G38482.60278.834 qps
MobileViT_V2_17514.3M5.52G25680.84412.759 qps
- 21k_ft1k14.3M5.52G25681.94412.759 qps
- 21k_ft1k, 38414.3M12.4G38482.93247.108 qps
MobileViT_V2_20018.4M7.12G25681.17394.325 qps
- 21k_ft1k18.4M7.12G25682.36394.325 qps
- 21k_ft1k, 38418.4M16.2G38483.41229.399 qps

MogaNet

ModelParamsFLOPsInputTop1 AccT4 Inference
MogaNetXtiny2.96M806M22476.5398.488 qps
MogaNetTiny5.20M1.11G22479.0362.409 qps
- 2565.20M1.45G25679.6335.372 qps
MogaNetSmall25.3M4.98G22483.4249.807 qps
MogaNetBase43.7M9.96G22484.2133.071 qps
MogaNetLarge82.5M15.96G22484.684.2045 qps

NAT

ModelParamsFLOPsInputTop1 AccT4 Inference
NAT_Mini20.0M2.73G22481.885.2324 qps
NAT_Tiny27.9M4.34G22483.262.6147 qps
NAT_Small50.7M7.84G22483.741.1545 qps
NAT_Base89.8M13.76G22484.330.8989 qps

NFNets

ModelParamsFLOPsInputTop1 AccT4 Inference
NFNetL035.07M7.13G28882.75
NFNetF071.5M12.58G25683.6
NFNetF1132.6M35.95G32084.7
NFNetF2193.8M63.24G35285.1
NFNetF3254.9M115.75G41685.7
NFNetF4316.1M216.78G51285.9
NFNetF5377.2M291.73G54486.0
NFNetF6, sam438.4M379.75G57686.5
NFNetF7499.5M481.80G608
ECA_NFNetL024.14M7.12G28882.58
ECA_NFNetL141.41M14.93G32084.01
ECA_NFNetL256.72M30.12G38484.70
ECA_NFNetL372.04M52.73G448

PVT_V2

ModelParamsFLOPsInputTop1 AccT4 Inference
PVT_V2B03.7M580.3M22470.5561.593 qps
PVT_V2B114.0M2.14G22478.7392.408 qps
PVT_V2B225.4M4.07G22482.0210.476 qps
PVT_V2B2_linear22.6M3.94G22482.1226.791 qps
PVT_V2B345.2M6.96G22483.1135.51 qps
PVT_V2B462.6M10.19G22483.697.666 qps
PVT_V2B582.0M11.81G22483.881.4798 qps

RegNetY

ModelParamsFLOPsInputTop1 AccT4 Inference
RegNetY04020.65M3.98G22482.3749.277 qps
RegNetY06430.58M6.36G22483.0436.946 qps
RegNetY08039.18M7.97G22483.17513.43 qps
RegNetY16083.59M15.92G22482.0338.046 qps
RegNetY320145.05M32.29G22482.5188.508 qps

RegNetZ

ModelParamsFLOPsInputTop1 AccT4 Inference
RegNetZB169.72M1.44G22479.868751.035 qps
RegNetZC1613.46M2.50G25682.164636.549 qps
RegNetZC16_EVO13.49M2.55G25681.9
RegNetZD3227.58M5.96G25683.422459.204 qps
RegNetZD823.37M3.95G25683.5460.021 qps
RegNetZD8_EVO23.46M4.61G25683.42
RegNetZE857.70M9.88G25684.5274.97 qps

RepViT

ModelParamsFLOPsInputTop1 AccT4 Inference
RepViT_M09, distillation5.10M0.82G22479.1
- deploy=True5.07M0.82G22479.1966.72 qps
RepViT_M10, distillation6.85M1.12G22480.31157.8 qps
- deploy=True6.81M1.12G22480.3
RepViT_M11, distillation8.29M1.35G22481.2846.682 qps
- deploy=True8.24M1.35G22481.21027.5 qps
RepViT_M15, distillation14.13M2.30G22482.5
- deploy=True14.05M2.30G22482.5
RepViT_M23, distillation23.01M4.55G22483.7
- deploy=True22.93M4.55G22483.7

ResMLP

ModelParamsFLOPsInputTop1 AccT4 Inference
ResMLP1215M3.02G22477.8928.402 qps
ResMLP2430M5.98G22480.8420.709 qps
ResMLP36116M8.94G22481.1309.513 qps
ResMLP_B24129M100.39G22483.678.3015 qps
- 21k_ft1k129M100.39G22484.478.3015 qps

ResNeSt

ModelParamsFLOPsInputTop1 AccT4 Inference
ResNest5028M5.38G22481.03534.627 qps
ResNest10149M13.33G25682.83257.074 qps
ResNest20071M35.55G32083.84118.183 qps
ResNest269111M77.42G41684.5461.167 qps

ResNetD

ModelParamsFLOPsInputTop1 AccT4 Inference
ResNet50D25.58M4.33G22480.530930.214 qps
ResNet101D44.57M8.04G22483.022502.268 qps
ResNet152D60.21M11.75G22483.680353.279 qps
ResNet200D64.69M15.25G22483.962287.73 qps

ResNetQ

ModelParamsFLOPsInputTop1 AccT4 Inference
ResNet51Q35.7M4.87G22482.36838.754 qps
ResNet61Q36.8M5.96G224730.245 qps

ResNeXt

ModelParamsFLOPsInputTop1 AccT4 Inference
ResNeXt50, (32x4d)25M4.23G22479.7681041.46 qps
- SWSL25M4.23G22482.1821041.46 qps
ResNeXt50D, (32x4d + deep)25M4.47G22479.6761010.94 qps
ResNeXt101, (32x4d)42M7.97G22480.334571.652 qps
- SWSL42M7.97G22483.230571.652 qps
ResNeXt101W, (32x8d)89M16.41G22479.308367.431 qps
- SWSL89M16.41G22484.284367.431 qps
ResNeXt101W_64, (64x4d)83.46M15.46G22482.46377.83 qps

SwinTransformerV2

ModelParamsFLOPsInputTop1 AccT4 Inference
SwinTransformerV2Tiny_ns28.3M4.69G22481.8289.205 qps
SwinTransformerV2Small_ns49.7M9.12G22483.5169.645 qps
SwinTransformerV2Tiny_window828.3M5.99G25681.8275.547 qps
SwinTransformerV2Tiny_window1628.3M6.75G25682.8217.207 qps
SwinTransformerV2Small_window849.7M11.63G25683.7157.559 qps
SwinTransformerV2Small_window1649.7M12.93G25684.1129.953 qps
SwinTransformerV2Base_window887.9M20.44G25684.2126.294 qps
SwinTransformerV2Base_window1687.9M22.17G25684.699.634 qps
SwinTransformerV2Base_window16, 21k_ft1k87.9M22.17G25686.299.634 qps
SwinTransformerV2Base_window24, 21k_ft1k87.9M55.89G38487.135.0508 qps
SwinTransformerV2Large_window16, 21k_ft1k196.7M48.03G25686.9
SwinTransformerV2Large_window24, 21k_ft1k196.7M117.1G38487.6

TinyNet

ModelParamsFLOPsInputTop1 AccT4 Inference
TinyNetE2.04M25.22M10659.862152.36 qps
TinyNetD2.34M53.35M15266.961905.56 qps
TinyNetC2.46M103.22M18471.231353.44 qps
TinyNetB3.73M206.28M18874.981196.06 qps
TinyNetA6.19M343.74M19277.65981.976 qps

TinyViT

ModelParamsFLOPsInputTop1 AccT4 Inference
TinyViT_5M, distill5.4M1.3G22479.1631.414 qps
- 21k_ft1k5.4M1.3G22480.7631.414 qps
TinyViT_11M, distill11M2.0G22481.5509.818 qps
- 21k_ft1k11M2.0G22483.2509.818 qps
TinyViT_21M, distill21M4.3G22483.1410.676 qps
- 21k_ft1k21M4.3G22484.8410.676 qps
- 21k_ft1k, 38421M13.8G38486.2199.458 qps
- 21k_ft1k, 51221M27.0G51286.5122.846 qps

UniFormer

ModelParamsFLOPsInputTop1 AccT4 Inference
UniformerSmall32, token_label22M3.66G22483.4577.334 qps
UniformerSmall6422M3.66G22482.9562.794 qps
- token_label22M3.66G22483.4562.794 qps
UniformerSmallPlus3224M4.24G22483.4546.82 qps
- token_label24M4.24G22483.9546.82 qps
UniformerSmallPlus6424M4.23G22483.4538.193 qps
- token_label24M4.23G22483.6538.193 qps
UniformerBase32, token_label50M8.32G22485.1272.485 qps
UniformerBase6450M8.31G22483.8286.963 qps
- token_label50M8.31G22484.8286.963 qps
UniformerLarge64, token_label100M19.79G22485.6154.761 qps
- token_label, 384100M63.11G38486.375.3487 qps

VanillaNet

| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| VanillaNet5 | 22.33M | 8.46G | 224 | 72.49 | 598.964 qps |
| - deploy=True | 15.52M | 5.17G | 224 | 72.49 | 798.199 qps |
| VanillaNet6 | 56.12M | 10.11G | 224 | 76.36 | 465.031 qps |
| - deploy=True | 32.51M | 6.00G | 224 | 76.36 | 655.944 qps |
| VanillaNet7 | 56.67M | 11.84G | 224 | 77.98 | 375.479 qps |
| - deploy=True | 32.80M | 6.90G | 224 | 77.98 | 527.723 qps |
| VanillaNet8 | 65.18M | 13.50G | 224 | 79.13 | 341.157 qps |
| - deploy=True | 37.10M | 7.75G | 224 | 79.13 | 479.328 qps |
| VanillaNet9 | 73.68M | 15.17G | 224 | 79.87 | 312.815 qps |
| - deploy=True | 41.40M | 8.59G | 224 | 79.87 | 443.464 qps |
| VanillaNet10 | 82.19M | 16.83G | 224 | 80.57 | 277.871 qps |
| - deploy=True | 45.69M | 9.43G | 224 | 80.57 | 408.082 qps |
| VanillaNet11 | 90.69M | 18.49G | 224 | 81.08 | 267.026 qps |
| - deploy=True | 50.00M | 10.27G | 224 | 81.08 | 377.239 qps |
| VanillaNet12 | 99.20M | 20.16G | 224 | 81.55 | 229.987 qps |
| - deploy=True | 54.29M | 11.11G | 224 | 81.55 | 358.076 qps |
| VanillaNet13 | 107.7M | 21.82G | 224 | 82.05 | 218.256 qps |
| - deploy=True | 58.59M | 11.96G | 224 | 82.05 | 334.244 qps |
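
The deploy=True rows are the fused inference structure: the training-time branches are merged away, so params and FLOPs drop while Top1 Acc stays unchanged. A minimal sketch, assuming the module name vanillanet (the deploy argument itself is taken from the table above):

    from keras_cv_attention_models import vanillanet  # module name assumed

    # Training-time structure with extra branches
    mm = vanillanet.VanillaNet5()

    # Fused inference structure: fewer params / FLOPs, same accuracy
    mm_deploy = vanillanet.VanillaNet5(deploy=True)
    print(mm.count_params(), mm_deploy.count_params())  # roughly 22.33M vs 15.52M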

VOLO

| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| VOLO_d1 | 27M | 4.82G | 224 | 84.2 |  |
| - 384 | 27M | 14.22G | 384 | 85.2 |  |
| VOLO_d2 | 59M | 9.78G | 224 | 85.2 |  |
| - 384 | 59M | 28.84G | 384 | 86.0 |  |
| VOLO_d3 | 86M | 13.80G | 224 | 85.4 |  |
| - 448 | 86M | 55.50G | 448 | 86.3 |  |
| VOLO_d4 | 193M | 29.39G | 224 | 85.7 |  |
| - 448 | 193M | 117.81G | 448 | 86.8 |  |
| VOLO_d5 | 296M | 53.34G | 224 | 86.1 |  |
| - 448 | 296M | 213.72G | 448 | 87.0 |  |
| - 512 | 296M | 279.36G | 512 | 87.1 |  |

WaveMLP

| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| WaveMLP_T | 17M | 2.47G | 224 | 80.9 | 523.4 qps |
| WaveMLP_S | 30M | 4.55G | 224 | 82.9 | 203.445 qps |
| WaveMLP_M | 44M | 7.92G | 224 | 83.3 | 147.155 qps |
| WaveMLP_B | 63M | 10.26G | 224 | 83.6 |  |

Detection Models

EfficientDet

| Model | Params | FLOPs | Input | COCO val AP | test AP | T4 Inference |
| ----- | ------ | ----- | ----- | ----------- | ------- | ------------ |
| EfficientDetD0 | 3.9M | 2.55G | 512 | 34.3 | 34.6 | 248.009 qps |
| - Det-AdvProp | 3.9M | 2.55G | 512 | 35.1 | 35.3 | 248.009 qps |
| EfficientDetD1 | 6.6M | 6.13G | 640 | 40.2 | 40.5 | 133.139 qps |
| - Det-AdvProp | 6.6M | 6.13G | 640 | 40.8 | 40.9 | 133.139 qps |
| EfficientDetD2 | 8.1M | 11.03G | 768 | 43.5 | 43.9 | 89.0523 qps |
| - Det-AdvProp | 8.1M | 11.03G | 768 | 44.3 | 44.3 | 89.0523 qps |
| EfficientDetD3 | 12.0M | 24.95G | 896 | 46.8 | 47.2 | 50.0498 qps |
| - Det-AdvProp | 12.0M | 24.95G | 896 | 47.7 | 48.0 | 50.0498 qps |
| EfficientDetD4 | 20.7M | 55.29G | 1024 | 49.3 | 49.7 | 28.0086 qps |
| - Det-AdvProp | 20.7M | 55.29G | 1024 | 50.4 | 50.4 | 28.0086 qps |
| EfficientDetD5 | 33.7M | 135.62G | 1280 | 51.2 | 51.5 |  |
| - Det-AdvProp | 33.7M | 135.62G | 1280 | 52.2 | 52.5 |  |
| EfficientDetD6 | 51.9M | 225.93G | 1280 | 52.1 | 52.6 |  |
| EfficientDetD7 | 51.9M | 325.34G | 1536 | 53.4 | 53.7 |  |
| EfficientDetD7X | 77.0M | 410.87G | 1536 | 54.4 | 55.1 |  |
| EfficientDetLite0 | 3.2M | 0.98G | 320 | 27.5 | 26.41 | 599.616 qps |
| EfficientDetLite1 | 4.2M | 1.97G | 384 | 32.6 | 31.50 | 369.273 qps |
| EfficientDetLite2 | 5.3M | 3.38G | 448 | 36.2 | 35.06 | 278.263 qps |
| EfficientDetLite3 | 8.4M | 7.50G | 512 | 39.9 | 38.77 | 180.871 qps |
| EfficientDetLite3X | 9.3M | 14.01G | 640 | 44.0 | 42.64 | 115.271 qps |
| EfficientDetLite4 | 15.1M | 20.20G | 640 | 44.4 | 43.18 | 95.4122 qps |
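
A minimal prediction sketch for the detection models, assuming they expose the same preprocess_input / decode_predictions helpers as the classification models and load COCO weights by default; the exact decoded output format is an assumption, so refer to the efficientdet sub directory for the full usage:

    from keras_cv_attention_models import efficientdet
    from keras_cv_attention_models.test_images import cat

    # COCO-pretrained detector; preprocess_input is assumed to resize to the 512 input listed above
    mm = efficientdet.EfficientDetD0()
    preds = mm(mm.preprocess_input(cat()))

    # Assumed decode format: (bboxes, labels, confidences) per image
    bboxes, labels, confidences = mm.decode_predictions(preds)[0]
    print(bboxes.shape, labels.shape, confidences.shape)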

YOLO_NAS

| Model | Params | FLOPs | Input | COCO val AP | test AP | T4 Inference |
| ----- | ------ | ----- | ----- | ----------- | ------- | ------------ |
| YOLO_NAS_S | 12.88M | 16.96G | 640 | 47.5 |  | 240.087 qps |
| - use_reparam_conv=False | 12.18M | 15.92G | 640 | 47.5 |  | 345.595 qps |
| YOLO_NAS_M | 33.86M | 47.12G | 640 | 51.55 |  | 128.96 qps |
| - use_reparam_conv=False | 31.92M | 43.91G | 640 | 51.55 |  | 167.935 qps |
| YOLO_NAS_L | 44.53M | 64.53G | 640 | 52.22 |  | 98.6069 qps |
| - use_reparam_conv=False | 42.02M | 59.95G | 640 | 52.22 |  | 131.11 qps |
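
As with VanillaNet's deploy mode, the use_reparam_conv=False rows drop the RepVGG-style training-time branches, which is consistent with the lower params / FLOPs and higher throughput at unchanged AP. A minimal sketch, assuming the module name yolo_nas (the keyword argument itself comes from the table):

    from keras_cv_attention_models import yolo_nas  # module name assumed

    mm = yolo_nas.YOLO_NAS_S()                               # reparam training structure
    mm_fused = yolo_nas.YOLO_NAS_S(use_reparam_conv=False)   # fused convs for inference
    print(mm.count_params(), mm_fused.count_params())  # roughly 12.88M vs 12.18M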

YOLOR

| Model | Params | FLOPs | Input | COCO val AP | test AP | T4 Inference |
| ----- | ------ | ----- | ----- | ----------- | ------- | ------------ |
| YOLOR_CSP | 52.9M | 60.25G | 640 | 50.0 | 52.8 | 118.746 qps |
| YOLOR_CSPX | 99.8M | 111.11G | 640 | 51.5 | 54.8 | 67.9444 qps |
| YOLOR_P6 | 37.3M | 162.87G | 1280 | 52.5 | 55.7 | 49.3128 qps |
| YOLOR_W6 | 79.9M | 226.67G | 1280 | 53.6 ? | 56.9 | 40.2355 qps |
| YOLOR_E6 | 115.9M | 341.62G | 1280 | 50.3 ? | 57.6 | 21.5719 qps |
| YOLOR_D6 | 151.8M | 467.88G | 1280 | 50.8 ? | 58.2 | 16.6061 qps |

YOLOV7

| Model | Params | FLOPs | Input | COCO val AP | test AP | T4 Inference |
| ----- | ------ | ----- | ----- | ----------- | ------- | ------------ |
| YOLOV7_Tiny | 6.23M | 2.90G | 416 | 33.3 |  | 845.903 qps |
| YOLOV7_CSP | 37.67M | 53.0G | 640 | 51.4 |  | 137.441 qps |
| YOLOV7_X | 71.41M | 95.0G | 640 | 53.1 |  | 82.0534 qps |
| YOLOV7_W6 | 70.49M | 180.1G | 1280 | 54.9 |  | 49.9841 qps |
| YOLOV7_E6 | 97.33M | 257.6G | 1280 | 56.0 |  | 31.3852 qps |
| YOLOV7_D6 | 133.9M | 351.4G | 1280 | 56.6 |  | 26.1346 qps |
| YOLOV7_E6E | 151.9M | 421.7G | 1280 | 56.8 |  | 20.1331 qps |

YOLOV8

| Model | Params | FLOPs | Input | COCO val AP | test AP | T4 Inference |
| ----- | ------ | ----- | ----- | ----------- | ------- | ------------ |
| YOLOV8_N | 3.16M | 4.39G | 640 | 37.3 |  | 614.042 qps |
| YOLOV8_S | 11.17M | 14.33G | 640 | 44.9 |  | 349.528 qps |
| YOLOV8_M | 25.90M | 39.52G | 640 | 50.2 |  | 160.212 qps |
| YOLOV8_L | 43.69M | 82.65G | 640 | 52.9 |  | 104.452 qps |
| YOLOV8_X | 68.23M | 129.0G | 640 | 53.9 |  | 66.0428 qps |
| YOLOV8_X6 | 97.42M | 522.6G | 1280 | 56.7 ? |  | 17.4368 qps |

YOLOX

| Model | Params | FLOPs | Input | COCO val AP | test AP | T4 Inference |
| ----- | ------ | ----- | ----- | ----------- | ------- | ------------ |
| YOLOXNano | 0.91M | 0.53G | 416 | 25.8 |  | 930.57 qps |
| YOLOXTiny | 5.06M | 3.22G | 416 | 32.8 |  | 745.2 qps |
| YOLOXS | 9.0M | 13.39G | 640 | 40.5 | 40.5 | 380.38 qps |
| YOLOXM | 25.3M | 36.84G | 640 | 46.9 | 47.2 | 181.084 qps |
| YOLOXL | 54.2M | 77.76G | 640 | 49.7 | 50.1 | 111.517 qps |
| YOLOXX | 99.1M | 140.87G | 640 | 51.5 | 51.5 | 62.3189 qps |

Language Models

GPT2

| Model | Params | FLOPs | vocab_size | LAMBADA PPL | T4 Inference |
| ----- | ------ | ----- | ---------- | ----------- | ------------ |
| GPT2_Base | 163.04M | 146.42G | 50257 | 35.13 | 51.4483 qps |
| GPT2_Medium | 406.29M | 415.07G | 50257 | 15.60 | 21.756 qps |
| GPT2_Large | 838.36M | 890.28G | 50257 | 10.87 |  |
| GPT2_XLarge | 1.638B | 1758.3G | 50257 | 8.63 |  |
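
A minimal forward-pass sketch, assuming the module name gpt2 and that the model maps a batch of token ids to per-token logits over the 50257-entry vocabulary listed above; for the full text-generation usage refer to the model's sub directory:

    import numpy as np
    from keras_cv_attention_models import gpt2  # module name assumed

    mm = gpt2.GPT2_Base()
    print(mm.input_shape, mm.output_shape)  # token ids in, logits over the 50257-token vocab out

    # Dummy forward pass; use the model's own context length in case it is fixed
    seq_len = mm.input_shape[1] or 16
    dummy_tokens = np.random.randint(0, 50257, size=(1, seq_len))
    logits = mm(dummy_tokens)
    print(logits.shape)  # expected (1, seq_len, 50257)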

LLaMA2

| Model | Params | FLOPs | vocab_size | Val loss | T4 Inference |
| ----- | ------ | ----- | ---------- | -------- | ------------ |
| LLaMA2_15M | 24.41M | 4.06G | 32000 | 1.072 |  |
| LLaMA2_42M | 58.17M | 50.7G | 32000 | 0.847 |  |
| LLaMA2_110M | 134.1M | 130.2G | 32000 | 0.760 |  |
| LLaMA2_1B | 1.10B | 2.50T | 32003 |  |  |
| LLaMA2_7B | 6.74B | 14.54T | 32000 |  |  |

Stable Diffusion

| Model | Params | FLOPs | Input | Download |
| ----- | ------ | ----- | ----- | -------- |
| ViTTextLargePatch14 | 123.1M | 6.67G | [None, 77] | vit_text_large_patch14_clip.h5 |
| Encoder | 34.16M | 559.6G | [None, 512, 512, 3] | encoder_v1_5.h5 |
| UNet | 859.5M | 404.4G | [None, 64, 64, 4] | unet_v1_5.h5 |
| Decoder | 49.49M | 1259.5G | [None, 64, 64, 4] | decoder_v1_5.h5 |
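
These rows are the individual v1.5 pipeline components: text encoder, image encoder, denoising UNet and decoder. A minimal text-to-image sketch, assuming a stable_diffusion.StableDiffusion wrapper class with a text_to_image method ties them together (both names are assumptions; refer to the stable_diffusion sub directory for the actual usage):

    from keras_cv_attention_models import stable_diffusion  # module name assumed

    # Hypothetical pipeline wrapper around the text encoder, UNet and decoder listed above
    mm = stable_diffusion.StableDiffusion()
    images = mm.text_to_image("a photograph of an astronaut riding a horse")
    print(images.shape)  # expected roughly (batch, 512, 512, 3)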

Segment Anything

| Model | Params | FLOPs | Input | COCO val mIoU | T4 Inference |
| ----- | ------ | ----- | ----- | ------------- | ------------ |
| MobileSAM | 5.74M | 39.4G | 1024 | 72.8 |  |
| TinySAM | 5.74M | 39.4G | 1024 |  |  |
| EfficientViT_SAM_L0 | 30.73M | 35.4G | 512 | 74.45 |  |

Licenses

  • This part is copied and modified from Github rwightman/pytorch-image-models.
  • Code. The code here is licensed MIT. It is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue. So far all of the pretrained weights available here are pretrained on ImageNet and COCO with a select few that have some additional pretraining.
  • ImageNet Pretrained Weights. ImageNet was released for non-commercial research purposes only (https://image-net.org/download). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.
  • COCO Pretrained Weights. These should follow the COCO dataset terms of use (https://cocodataset.org/#termsofuse). The annotations in COCO dataset belong to the COCO Consortium and are licensed under a Creative Commons Attribution 4.0 License. The COCO Consortium does not own the copyright of the images. Use of the images must abide by the Flickr Terms of Use. The users of the images accept full responsibility for the use of the dataset, including but not limited to the use of any copies of copyrighted images that they may create from the dataset.
  • Pretrained on more than ImageNet and COCO. Several weights included or references here were pretrained with proprietary datasets that I do not have access to. These include the Facebook WSL, SSL, SWSL ResNe(Xt) and the Google Noisy Student EfficientNet models. The Facebook models have an explicit non-commercial license (CC-BY-NC 4.0, https://github.com/facebookresearch/semi-supervised-ImageNet1K-models, https://github.com/facebookresearch/WSL-Images). The Google models do not appear to have any restriction beyond the Apache 2.0 license (and ImageNet concerns). In either case, you should contact Facebook or Google with any questions.

Citing

  • BibTeX
    @misc{leondgarse,
      author = {Leondgarse},
      title = {Keras CV Attention Models},
      year = {2022},
      publisher = {GitHub},
      journal = {GitHub repository},
      doi = {10.5281/zenodo.6506947},
      howpublished = {\url{https://github.com/leondgarse/keras_cv_attention_models}}
    }
    
