Keras_cv_attention_models
- WARNING: currently NOT compatible with keras 3.x. If using tensorflow>=2.16.0, install the matching tf-keras manually: pip install tf-keras~=$(pip show tensorflow | awk -F ': ' '/Version/{print $2}'). When importing, import this package ahead of Tensorflow, or set export TF_USE_LEGACY_KERAS=1.
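For example, a minimal sketch of the two options in Python, assuming a tf-keras matching the installed tensorflow version is already installed:
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"  # option 1: set before importing tensorflow
import tensorflow as tf

# option 2: simply import this package ahead of tensorflow
# import kecam
# import tensorflow as tf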
- Downloading and loading a model from an h5 file directly is not recommended; prefer building the model and loading weights, e.g. import kecam; mm = kecam.models.LCNet050().
- coco_train_script.py for TF is still under testing...
General Usage
Basic
- The READMEs assume these default imports without repeating them.
import os
import sys
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow import keras
- Install as a pip package. kecam is a short alias of this package. Note: the pip package kecam doesn't declare any backend requirement, so make sure either Tensorflow or PyTorch is installed beforehand. For PyTorch backend usage, refer to Keras PyTorch Backend.
pip install -U kecam
pip install -U keras-cv-attention-models
pip install -U git+https://github.com/leondgarse/keras_cv_attention_models
Refer to each sub-directory for detailed usage.
- Basic model prediction
from keras_cv_attention_models import volo
mm = volo.VOLO_d1(pretrained="imagenet")
""" Run predict """
import tensorflow as tf
from tensorflow import keras
from keras_cv_attention_models.test_images import cat
img = cat()
imm = keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
pred = tf.nn.softmax(pred).numpy()
print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
Or just use the model's preset preprocess_input and decode_predictions:
from keras_cv_attention_models import coatnet
mm = coatnet.CoAtNet0()
from keras_cv_attention_models.test_images import cat
preds = mm(mm.preprocess_input(cat()))
print(mm.decode_predictions(preds))
The preset preprocess_input and decode_predictions are also compatible with the PyTorch backend.
os.environ['KECAM_BACKEND'] = 'torch'
from keras_cv_attention_models import caformer
mm = caformer.CAFormerS18()
from keras_cv_attention_models.test_images import cat
preds = mm(mm.preprocess_input(cat()))
print(preds.shape)
print(mm.decode_predictions(preds))
- Set num_classes=0 to exclude the model's top GlobalAveragePooling2D + Dense layers.
from keras_cv_attention_models import resnest
mm = resnest.ResNest50(num_classes=0)
print(mm.output_shape)
- Setting num_classes={custom output classes} to anything other than 1000 or 0 will just skip loading the header Dense layer weights, as model.load_weights(weight_file, by_name=True, skip_mismatch=True) is used for loading weights.
from keras_cv_attention_models import swin_transformer_v2
mm = swin_transformer_v2.SwinTransformerV2Tiny_window8(num_classes=64)
- Reload your own model weights by setting pretrained="xxx.h5". This is better than calling model.load_weights directly when reloading a model with a different input_shape and mismatching weight shapes.
import os
from keras_cv_attention_models import coatnet
pretrained = os.path.expanduser('~/.keras/models/coatnet0_224_imagenet.h5')
mm = coatnet.CoAtNet1(input_shape=(384, 384, 3), pretrained=pretrained)
- The alias name kecam can be used instead of keras_cv_attention_models. It's an __init__.py only package containing just from keras_cv_attention_models import *.
import kecam
mm = kecam.yolor.YOLOR_CSP()
imm = kecam.test_images.dog_cat()
preds = mm(mm.preprocess_input(imm))
bboxes, labels, confidences = mm.decode_predictions(preds)[0]
kecam.coco.show_image_with_bboxes(imm, bboxes, labels, confidences)
- The FLOPs calculation method is from TF 2.0 Feature: Flops calculation #32809. For the PyTorch backend, thop is needed: pip install thop.
from keras_cv_attention_models import coatnet, resnest, model_surgery
model_surgery.get_flops(coatnet.CoAtNet0())
model_surgery.get_flops(resnest.ResNest50())
- [Deprecated] tensorflow_addons is not imported by default. When reloading a model that depends on GroupNormalization, like MobileViTV2, from h5 directly, tensorflow_addons needs to be imported manually first.
import tensorflow_addons as tfa
model_path = os.path.expanduser('~/.keras/models/mobilevit_v2_050_256_imagenet.h5')
mm = keras.models.load_model(model_path)
- Export a TF model to ONNX. Needs tf2onnx for TF: pip install onnx tf2onnx onnxsim onnxruntime. When using the PyTorch backend, ONNX export is handled by PyTorch itself.
from keras_cv_attention_models import volo, nat, model_surgery
mm = nat.DiNAT_Small(pretrained=True)
model_surgery.export_onnx(mm, fuse_conv_bn=True, batch_size=1, simplify=True)
from keras_cv_attention_models.imagenet import eval_func
aa = eval_func.ONNXModelInterf(mm.name + '.onnx')
inputs = np.random.uniform(size=[1, *mm.input_shape[1:]]).astype('float32')
print(f"{np.allclose(aa(inputs), mm(inputs), atol=1e-5) = }")
- Model summary model_summary.csv contains gathered model info.
  - params: model parameter count in M
  - flops: FLOPs in G
  - input: model input shape
  - acc_metrics: ImageNet Top1 Accuracy for recognition models, COCO val AP for detection models
  - inference_qps: T4 inference queries per second with batch_size=1 + trtexec
  - extra: any extra training info
from keras_cv_attention_models import plot_func
plot_series = [
"efficientnetv2", 'tinynet', 'lcnet', 'mobilenetv3', 'fasternet', 'fastervit', 'ghostnet',
'inceptionnext', 'efficientvit_b', 'mobilevit', 'convnextv2', 'efficientvit_m', 'hiera',
]
plot_func.plot_model_summary(
plot_series, model_table="model_summary.csv", log_scale_x=True, allow_extras=['mae_in1k_ft1k']
)
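To look up individual models, model_summary.csv can also be read directly with pandas. The column names below follow the field list above; treat them as assumptions and check the printed columns first:
import pandas as pd
aa = pd.read_csv("model_summary.csv")
print(aa.columns.tolist())  # inspect the actual column names
print(aa.sort_values("acc_metrics", ascending=False).head())  # highest-accuracy entries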

- Code format uses line-length=160:
find ./* -name "*.py" | grep -v __init__ | xargs -I {} black -l 160 {}
T4 Inference
- T4 Inference results in the model tables are tested using trtexec on a Tesla T4 with CUDA=12.0.1-1, Driver=525.60.13. All models are exported as ONNX using the PyTorch backend, with batch_size=1 only. Note: this data is for reference only, and varies with batch size, benchmark tool, platform, and implementation.
- All results are tested using the colab trtexec.ipynb, so they are reproducible by anyone.
os.environ["KECAM_BACKEND"] = "torch"
from keras_cv_attention_models import convnext, test_images, imagenet
mm = convnext.ConvNeXtTiny()
mm.export_onnx(simplify=True)
tt = imagenet.eval_func.ONNXModelInterf('convnext_tiny.onnx')
print(mm.decode_predictions(tt(mm.preprocess_input(test_images.cat()))))
""" Run trtexec benchmark """
!trtexec --onnx=convnext_tiny.onnx --fp16 --allowGPUFallback --useSpinWait --useCudaGraph
Layers
- attention_layers is an __init__.py only module that imports core layers defined in the model architectures, like RelativePositionalEmbedding from botnet, outlook_attention from volo, and many other Positional Embedding Layers / Attention Blocks.
from keras_cv_attention_models import attention_layers
aa = attention_layers.RelativePositionalEmbedding()
print(f"{aa(tf.ones([1, 4, 14, 16, 256])).shape = }")
Model surgery
- model_surgery includes functions used to change model parameters after the model is built.
from keras_cv_attention_models import model_surgery
mm = keras.applications.ResNet50()
mm = model_surgery.replace_ReLU(mm, target_activation='PReLU')
mm = model_surgery.convert_to_fused_conv_bn_model(mm)
ImageNet training and evaluating
- ImageNet contains more detailed usage and some comparison results.
- Init Imagenet dataset using tensorflow_datasets #9.
- For a custom dataset, custom_dataset_script.py can be used to create a json format file, which can then be passed as --data_name xxx.json for training; detailed usage can be found in Custom recognition dataset. A sketch is shown below.
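A minimal sketch with hypothetical image folders; the generated json filename is an assumption here, check the script output for the actual one:
python3 custom_dataset_script.py --train_images my_dataset/train/ --test_images my_dataset/test/
CUDA_VISIBLE_DEVICES='0' python3 train_script.py --data_name my_dataset.json -s my_custom_test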
- Another way of creating a custom dataset is using tfds.load; refer to Writing custom datasets and Creating private tensorflow_datasets from tfds #48 by @Medicmind.
- An example of running an AWS Sagemaker estimator job using keras_cv_attention_models can be found in AWS Sagemaker script example by @Medicmind.
- aotnet.AotNet50 default parameters are set as a typical ResNet50 architecture, with Conv2D use_bias=False and padding like PyTorch.
- Default parameters for train_script.py follow the A3 configuration from ResNet strikes back: An improved training procedure in timm, with batch_size=256, input_shape=(160, 160).
CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 train_script.py --seed 0 -s aotnet50
CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m aotnet50_epoch_103_val_acc_0.7674.h5 -i 224 --central_crop 0.95

- Restore from a break point by setting --restore_path and --initial_epoch, keeping other parameters the same. restore_path has higher priority than model and additional_model_kwargs, and also restores the optimizer and loss. initial_epoch is mainly for the learning rate scheduler. If not sure where training stopped, check checkpoints/{save_name}_hist.json.
import json
with open("checkpoints/aotnet50_hist.json", "r") as ff:
aa = json.load(ff)
len(aa['lr'])
CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 train_script.py --seed 0 -r checkpoints/aotnet50_latest.h5 -I 41
- eval_script.py is used for evaluating model accuracy. EfficientNetV2 self tested imagenet accuracy #19 shows how different parameters affect model accuracy.
CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m regnet.RegNetZD8
CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m timm.models.resmlp_12_224 --input_shape 224
CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m checkpoints/xxx.h5
CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m xxx.tflite
- Progressive training refers to PDF 2104.00298 EfficientNetV2: Smaller Models and Faster Training. AotNet50 A3 with progressive input shapes 96 128 160:
CUDA_VISIBLE_DEVICES='1' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 progressive_train_script.py \
--progressive_epochs 33 66 -1 \
--progressive_input_shapes 96 128 160 \
--progressive_magnitudes 2 4 6 \
-s aotnet50_progressive_3_lr_steps_100 --seed 0

- Transfer learning with freeze_backbone or freeze_norm_layers: see EfficientNetV2B0 transfer learning on cifar10 testing freezing backbone #55. A minimal sketch follows.
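A minimal sketch of the idea using a headless backbone (num_classes=0) and a fresh classification head; freezing via trainable=False here stands in for the script options:
from keras_cv_attention_models import efficientnet
backbone = efficientnet.EfficientNetV2B0(input_shape=(160, 160, 3), num_classes=0)  # headless backbone with imagenet weights
backbone.trainable = False  # freeze backbone weights
nn = keras.layers.GlobalAveragePooling2D()(backbone.outputs[0])
outputs = keras.layers.Dense(10, activation="softmax")(nn)  # e.g. 10 cifar10 classes
model = keras.models.Model(backbone.inputs[0], outputs)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["acc"])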
- Token label train test on CIFAR10 #57. Currently not working as well as expected. Token label is an implementation of Github zihangJiang/TokenLabeling, paper PDF 2104.10858 All Tokens Matter: Token Labeling for Training Better Vision Transformers.
COCO training and evaluating
- Currently still under testing.
- COCO contains more detailed usage.
- custom_dataset_script.py can be used to create a json format file, which can then be passed as --data_name xxx.json for training; detailed usage can be found in Custom detection dataset.
- Default parameters for coco_train_script.py are EfficientDetD0 with input_shape=(256, 256, 3), batch_size=64, mosaic_mix_prob=0.5, freeze_backbone_epochs=32, total_epochs=105. Technically, any combination of a pyramid-structure backbone + EfficientDet / YOLOX header / YOLOR header + anchor_free / yolor / efficientdet anchors is supported.
- Currently 4 anchor types are supported. The parameter anchors_mode controls which anchors to use, with values in ["efficientdet", "anchor_free", "yolor", "yolov8"]. The default None uses the det_header presets.
- NOTE: YOLOV8 has a default regression_len=64 for the bbox output length. Typically it's 4 for other detection models; for yolov8 it's reg_max=16 -> regression_len = 16 * 4 == 64.
| anchors_mode | use_object_scores | num_anchors | anchor_scale | aspect_ratios | num_scales | grid_zero_start |
| ------------ | ----------------- | ----------- | ------------ | ------------- | ---------- | --------------- |
| efficientdet | False | 9 | 4 | [1, 2, 0.5] | 3 | False |
| anchor_free | True | 1 | 1 | [1] | 1 | True |
| yolor | True | 3 | None | presets | None | offset=0.5 |
| yolov8 | False | 1 | 1 | [1] | 1 | False |
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py -i 512 -p adamw --freeze_backbone_epochs 16 --lr_decay_steps 50
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone efficientnet.EfficientNetV2B0 --det_header efficientdet.EfficientDetD0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone resnest.ResNest50 --anchors_mode anchor_free
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone uniformer.UniformerSmall32 --anchors_mode yolor
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolox.YOLOXS --freeze_backbone_epochs 0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolox.YOLOXS --anchors_mode efficientdet --freeze_backbone_epochs 0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone coatnet.CoAtNet0 --det_header yolox.YOLOX --anchors_mode yolor
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolor.YOLOR_P6 --freeze_backbone_epochs 0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolor.YOLOR_P6 --anchors_mode anchor_free --freeze_backbone_epochs 0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone convnext.ConvNeXtTiny --det_header yolor.YOLOR --anchors_mode yolor
Note: COCO training is still under testing; parameters and default behaviors may change. Take the risk if you would like to help developing.
- coco_eval_script.py is used for evaluating model AP / AR on the COCO validation set. It depends on pycocotools (pip install pycocotools), which is not in the package requirements. More usage can be found in COCO Evaluation.
CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m efficientdet.EfficientDetD0 --resize_method bilinear --disable_antialias
CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m yolox.YOLOXTiny --use_bgr_input --nms_method hard --nms_iou_or_sigma 0.65
CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m yolor.YOLOR_CSP --nms_method hard --nms_iou_or_sigma 0.65 \
--nms_max_output_size 300 --nms_topk -1 --letterbox_pad 64 --input_shape 704
CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m checkpoints/yoloxtiny_yolor_anchor.h5
- [Experimental] Training using the PyTorch backend:
import os, sys, torch
os.environ["KECAM_BACKEND"] = "torch"
from keras_cv_attention_models.yolov8 import train, yolov8
from keras_cv_attention_models import efficientnet
global_device = torch.device("cuda:0") if torch.cuda.is_available() and int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")) >= 0 else torch.device("cpu")
bb = efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0)
model = yolov8.YOLOV8_N(backbone=bb, classifier_activation=None, pretrained=None).to(global_device)
ema = train.train(model, dataset_path="coco.json", initial_epoch=0)

CLIP training and evaluating
- CLIP contains more detailed usage.
- custom_dataset_script.py can be used to create a tsv / json format file, which can then be passed as --data_name xxx.tsv for training; detailed usage can be found in Custom caption dataset.
- Train using clip_train_script.py on COCO captions. The default --data_path is a small testing one, datasets/coco_dog_cat/captions.tsv.
CUDA_VISIBLE_DEVICES=1 TF_XLA_FLAGS="--tf_xla_auto_jit=2" python clip_train_script.py -i 160 -b 128 \
--text_model_pretrained None --data_path coco_captions.tsv
Train using the PyTorch backend by setting KECAM_BACKEND='torch':
KECAM_BACKEND='torch' CUDA_VISIBLE_DEVICES=1 python clip_train_script.py -i 160 -b 128 \
--text_model_pretrained None --data_path coco_captions.tsv

Text training
- Currently it's only a simple implementation modified from Github karpathy/nanoGPT.
- Train using text_train_script.py. As the dataset is randomly sampled, steps_per_epoch needs to be specified.
CUDA_VISIBLE_DEVICES=1 TF_XLA_FLAGS="--tf_xla_auto_jit=2" python text_train_script.py -m LLaMA2_15M \
--steps_per_epoch 8000 --batch_size 8 --tokenizer SentencePieceTokenizer
Train using the PyTorch backend by setting KECAM_BACKEND='torch':
KECAM_BACKEND='torch' CUDA_VISIBLE_DEVICES=1 python text_train_script.py -m LLaMA2_15M \
--steps_per_epoch 8000 --batch_size 8 --tokenizer SentencePieceTokenizer
Plotting
from keras_cv_attention_models import plot_func
hists = ['checkpoints/text_llama2_15m_tensorflow_hist.json', 'checkpoints/text_llama2_15m_torch_hist.json']
plot_func.plot_hists(hists, addition_plots=['val_loss', 'lr'], skip_first=3)

DDPM training
- Stable Diffusion contains more detailed usage.
- Note: works better with the PyTorch backend; the Tensorflow one seems to overfit when training longer, like --epochs 200, and its evaluation runs ~5 times slower.
- The dataset can be a directory containing images for basic DDPM training using images only, or a recognition json file created following Custom recognition dataset, in which case labels are used as instruction.
python custom_dataset_script.py --train_images cifar10/train/ --test_images cifar10/test/
- Train using ddpm_train_script.py on cifar10 with labels. The default --data_path is the builtin cifar10.
TF_XLA_FLAGS="--tf_xla_auto_jit=2" CUDA_VISIBLE_DEVICES=1 python ddpm_train_script.py --eval_interval 50
Train using the PyTorch backend by setting KECAM_BACKEND='torch':
KECAM_BACKEND='torch' CUDA_VISIBLE_DEVICES=1 python ddpm_train_script.py

Visualizing
- Visualizing is for visualizing convnet filters or attention map scores.
- make_and_apply_gradcam_heatmap is for Grad-CAM class activation visualization.
from keras_cv_attention_models import visualizing, test_images, resnest
mm = resnest.ResNest50()
img = test_images.dog()
superimposed_img, heatmap, preds = visualizing.make_and_apply_gradcam_heatmap(mm, img, layer_name="auto")
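The returned superimposed_img is a plain image array and can be shown with matplotlib:
plt.imshow(superimposed_img)
plt.axis("off")
plt.show()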

- plot_attention_score_maps is for visualizing model attention score maps.
from keras_cv_attention_models import visualizing, test_images, botnet
img = test_images.dog()
_ = visualizing.plot_attention_score_maps(botnet.BotNetSE33T(), img)

TFLite Conversion
- Currently TFLite does not support tf.image.extract_patches / tf.transpose with len(perm) > 4. Some operations may be supported in the latest or tf-nightly version; for instance the previously unsupported gelu / Conv2D with groups>1 are working now. Try those if encountering issues.
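A quick way of checking whether a specific model converts is simply attempting the conversion and catching the error; a minimal sketch:
from keras_cv_attention_models import coatnet
mm = coatnet.CoAtNet0(pretrained=None)
converter = tf.lite.TFLiteConverter.from_keras_model(mm)
try:
    open(mm.name + ".tflite", "wb").write(converter.convert())
    print("Converted OK:", mm.name + ".tflite")
except Exception as ee:
    print("Conversion failed:", ee)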
- More discussion can be found in Converting a trained keras CV attention model to TFLite #17. Some speed testing results can be found in How to speed up inference on a quantized model #44.
- Functions like model_surgery.convert_groups_conv2d_2_split_conv2d and model_surgery.convert_gelu_to_approximate are not needed with an up-to-date TF version.
- Converting VOLO / HaloNet models is not supported, because they need a longer tf.transpose perm.
- model_surgery.convert_dense_to_conv converts all Dense layers with 3D / 4D inputs to Conv1D / Conv2D, as TFLite xnnpack currently doesn't support them.
from keras_cv_attention_models import beit, model_surgery, efficientformer, mobilevit
mm = efficientformer.EfficientFormerL1()
mm = model_surgery.convert_dense_to_conv(mm)
converter = tf.lite.TFLiteConverter.from_keras_model(mm)
open(mm.name + ".tflite", "wb").write(converter.convert())
| Model | Dense, use_xnnpack=false | Conv, use_xnnpack=false | Conv, use_xnnpack=true |
| ----- | ------------------------ | ----------------------- | ---------------------- |
| MobileViT_S | Inference (avg) 215371 us | Inference (avg) 163836 us | Inference (avg) 163817 us |
| EfficientFormerL1 | Inference (avg) 126829 us | Inference (avg) 107053 us | Inference (avg) 107132 us |
- model_surgery.convert_extract_patches_to_conv converts
tf.image.extract_patches
to a Conv2D
version:
from keras_cv_attention_models import cotnet, model_surgery
from keras_cv_attention_models.imagenet import eval_func
mm = cotnet.CotNetSE50D()
mm = model_surgery.convert_groups_conv2d_2_split_conv2d(mm)
mm = model_surgery.convert_extract_patches_to_conv(mm)
converter = tf.lite.TFLiteConverter.from_keras_model(mm)
open(mm.name + ".tflite", "wb").write(converter.convert())
test_inputs = np.random.uniform(size=[1, *mm.input_shape[1:]])
print(np.allclose(mm(test_inputs), eval_func.TFLiteModelInterf(mm.name + '.tflite')(test_inputs), atol=1e-7))
- model_surgery.prepare_for_tflite is just a combination of the above functions:
from keras_cv_attention_models import beit, model_surgery
mm = beit.BeitBasePatch16()
mm = model_surgery.prepare_for_tflite(mm)
converter = tf.lite.TFLiteConverter.from_keras_model(mm)
open(mm.name + ".tflite", "wb").write(converter.convert())
- Detection models including efficientdet / yolox / yolor can be converted to TFLite format directly. If DecodePredictions should also be included in the TFLite model, set use_static_output=True for DecodePredictions, as TFLite requires a more static output shape. The model output shape will then be fixed as [batch, max_output_size, 6]. The last dimension 6 means [bbox_top, bbox_left, bbox_bottom, bbox_right, label_index, confidence], and the valid entries are those where confidence > 0.
""" Init model """
from keras_cv_attention_models import efficientdet
model = efficientdet.EfficientDetD0(pretrained="coco")
""" Create a model with DecodePredictions using `use_static_output=True` """
model.decode_predictions.use_static_output = True
nn = model.decode_predictions(model.outputs[0], score_threshold=0.5)
bb = keras.models.Model(model.inputs[0], nn)
""" Convert TFLite """
converter = tf.lite.TFLiteConverter.from_keras_model(bb)
open(bb.name + ".tflite", "wb").write(converter.convert())
""" Inference test """
from keras_cv_attention_models.imagenet import eval_func
from keras_cv_attention_models import test_images
dd = eval_func.TFLiteModelInterf(bb.name + ".tflite")
imm = test_images.cat()
inputs = tf.expand_dims(tf.image.resize(imm, dd.input_shape[1:-1]), 0)
inputs = keras.applications.imagenet_utils.preprocess_input(inputs, mode='torch')
preds = dd(inputs)[0]
print(f"{preds.shape = }")
pred = preds[preds[:, -1] > 0]
bboxes, labels, confidences = pred[:, :4], pred[:, 4], pred[:, -1]
print(f"{bboxes = }, {labels = }, {confidences = }")
""" Show result """
from keras_cv_attention_models.coco import data
data.show_image_with_bboxes(imm, bboxes, labels, confidences, num_classes=90)
Using PyTorch as backend
- Experimental Keras PyTorch Backend.
- Set the os environment variable export KECAM_BACKEND='torch' to enable this PyTorch backend.
- Currently supports most recognition and detection models except hornet*gf / nfnets / volo. For detection models, torchvision.ops.nms is used when running prediction; see the detection sketch after the basic example below.
- Basic model build and prediction.
  - Will load the same h5 weights as the TF one if available.
  - Note: input_shape will auto-fit the image data format. Given input_shape=(224, 224, 3) or input_shape=(3, 224, 224), both will be set to (3, 224, 224) if channels_first.
  - Note: the model is set to eval mode by default.
os.environ['KECAM_BACKEND'] = 'torch'
from keras_cv_attention_models import res_mlp
mm = res_mlp.ResMLP12()
print(f"{mm.input_shape = }")
import torch
print(f"{isinstance(mm, torch.nn.Module) = }")
from keras_cv_attention_models.test_images import cat
print(mm.decode_predictions(mm(mm.preprocess_input(cat())))[0])
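As referenced above, detection models follow the same pattern under the torch backend, with torchvision.ops.nms handling the NMS step; a sketch mirroring the earlier YOLOR example:
os.environ['KECAM_BACKEND'] = 'torch'
from keras_cv_attention_models import yolor, test_images, coco
mm = yolor.YOLOR_CSP()
imm = test_images.dog_cat()
preds = mm(mm.preprocess_input(imm))
bboxes, labels, confidences = mm.decode_predictions(preds)[0]
coco.show_image_with_bboxes(imm, bboxes, labels, confidences)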
- Export typical PyTorch onnx / pth.
import torch
torch.onnx.export(mm, torch.randn(1, 3, *mm.input_shape[2:]), mm.name + ".onnx")
mm.export_onnx()
mm.export_pth()
- Save weights as h5. This h5 can also be loaded in a typical TF backend model. Currently only weights are supported, without model structure.
mm.save_weights("foo.h5")
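The saved weights can then be reloaded under the TF backend by building the same architecture and passing the file as pretrained:
# In a TF backend session, KECAM_BACKEND not set to 'torch'
from keras_cv_attention_models import res_mlp
mm = res_mlp.ResMLP12(pretrained="foo.h5")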
- Training with compile and fit. Note: loss function arguments should be y_true, y_pred, while typical torch loss functions use y_pred, y_true.
import torch
from keras_cv_attention_models.backend import models, layers
mm = models.Sequential([layers.Input([3, 32, 32]), layers.Conv2D(32, 3), layers.GlobalAveragePooling2D(), layers.Dense(10)])
if torch.cuda.is_available():
_ = mm.to("cuda")
xx = torch.rand([64, *mm.input_shape[1:]])
yy = torch.functional.F.one_hot(torch.randint(0, mm.output_shape[-1], size=[64]), mm.output_shape[-1]).float()
loss = lambda y_true, y_pred: (y_true - y_pred.float()).abs().mean()
mm.compile(optimizer="AdamW", loss=loss, metrics='acc', grad_accumulate=4)
mm.fit(xx, yy, epochs=2, batch_size=4)
Using keras core as backend
- [Experimental] Set the os environment variable export KECAM_BACKEND='keras_core' to enable this keras_core backend. keras>3.0 is not used, as it still doesn't work with TensorFlow==2.15.0. keras-core has its own backends, supporting tensorflow / torch / jax, selected by editing the ~/.keras/keras.json "backend" value.
- Currently most recognition models are supported, except HaloNet / BotNet; GPT2 / LLaMA2 are also supported.
- Basic model build and prediction.
!pip install sentencepiece
os.environ['KECAM_BACKEND'] = 'keras_core'
os.environ['KERAS_BACKEND'] = 'jax'
import kecam
print(f"{kecam.backend.backend() = }")
mm = kecam.llama2.LLaMA2_42M()
mm.run_prediction('As evening fell, a maiden stood at the edge of a wood. In her hands,')
Recognition Models
AotNet
- Keras AotNet is just a
ResNet
/ ResNetV2
like framework, that set parameters like attn_types
and se_ratio
and others, which is used to apply different types attention layer. Works like byoanet
/ byobnet
from timm
.
- The default parameter set is a typical ResNet architecture with Conv2D use_bias=False and padding like PyTorch.
from keras_cv_attention_models import aotnet
attn_types = [None, "outlook", ["bot", "halo"] * 50, "cot"]  # per-stack attention types; 50 just needs to be larger than that stack's block count
se_ratio = [0.25, 0, 0, 0]
model = aotnet.AotNet50V2(attn_types=attn_types, se_ratio=se_ratio, stem_type="deep", strides=1)
model.summary()
BEiT
BEiTV2
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| BeitV2BasePatch16 | 86.53M | 17.61G | 224 | 85.5 | 322.52 qps |
| - 21k_ft1k | 86.53M | 17.61G | 224 | 86.5 | 322.52 qps |
| BeitV2LargePatch16 | 304.43M | 61.68G | 224 | 87.3 | 105.734 qps |
| - 21k_ft1k | 304.43M | 61.68G | 224 | 88.4 | 105.734 qps |
BotNet
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| BotNet50 | 21M | 5.42G | 224 | | 746.454 qps |
| BotNet101 | 41M | 9.13G | 224 | | 448.102 qps |
| BotNet152 | 56M | 12.84G | 224 | | 316.671 qps |
| BotNet26T | 12.5M | 3.30G | 256 | 79.246 | 1188.84 qps |
| BotNextECA26T | 10.59M | 2.45G | 256 | 79.270 | 1038.19 qps |
| BotNetSE33T | 13.7M | 3.89G | 256 | 81.2 | 610.429 qps |
CAFormer
CMT
CoaT
CoAtNet
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| CoAtNet0, 160, (105 epochs) | 23.3M | 2.09G | 160 | 80.48 | 584.059 qps |
| CoAtNet0, (305 epochs) | 23.8M | 4.22G | 224 | 82.79 | 400.333 qps |
| CoAtNet0 | 25M | 4.6G | 224 | 82.0 | 400.333 qps |
| - use_dw_strides=False | 25M | 4.2G | 224 | 81.6 | 461.197 qps |
| CoAtNet1 | 42M | 8.8G | 224 | 83.5 | 206.954 qps |
| - use_dw_strides=False | 42M | 8.4G | 224 | 83.3 | 228.938 qps |
| CoAtNet2 | 75M | 16.6G | 224 | 84.1 | 156.359 qps |
| - use_dw_strides=False | 75M | 15.7G | 224 | 84.1 | 165.846 qps |
| CoAtNet2, 21k_ft1k | 75M | 16.6G | 224 | 87.1 | 156.359 qps |
| CoAtNet3 | 168M | 34.7G | 224 | 84.5 | 95.0703 qps |
| CoAtNet3, 21k_ft1k | 168M | 34.7G | 224 | 87.6 | 95.0703 qps |
| CoAtNet3, 21k_ft1k | 168M | 203.1G | 512 | 87.9 | 95.0703 qps |
| CoAtNet4, 21k_ft1k | 275M | 360.9G | 512 | 88.1 | 74.6022 qps |
| CoAtNet4, 21k_ft1k, PT-RA-E150 | 275M | 360.9G | 512 | 88.56 | 74.6022 qps |
ConvNeXt
ConvNeXtV2
CoTNet
CSPNeXt
DaViT
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| DaViT_T | 28.36M | 4.56G | 224 | 82.8 | 224.563 qps |
| DaViT_S | 49.75M | 8.83G | 224 | 84.2 | 145.838 qps |
| DaViT_B | 87.95M | 15.55G | 224 | 84.6 | 114.527 qps |
| DaViT_L, 21k_ft1k | 196.8M | 103.2G | 384 | 87.5 | 34.7015 qps |
| DaViT_H, 1.5B | 348.9M | 327.3G | 512 | 90.2 | 12.363 qps |
| DaViT_G, 1.5B | 1.406B | 1.022T | 512 | 90.4 | |
DiNAT
DINOv2
EdgeNeXt
EfficientFormer
EfficientFormerV2
EfficientNet
EfficientNetEdgeTPU
EfficientNetV2
EfficientViT_B
EfficientViT_M
EVA
EVA02
FasterNet
FasterViT
FastViT
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| FastViT_T8 | 4.03M | 0.65G | 256 | 76.2 | 1020.29 qps |
| - distill | 4.03M | 0.65G | 256 | 77.2 | 1020.29 qps |
| - deploy=True | 3.99M | 0.64G | 256 | 76.2 | 1323.14 qps |
| FastViT_T12 | 7.55M | 1.34G | 256 | 79.3 | 734.867 qps |
| - distill | 7.55M | 1.34G | 256 | 80.3 | 734.867 qps |
| - deploy=True | 7.50M | 1.33G | 256 | 79.3 | 956.332 qps |
| FastViT_S12 | 9.47M | 1.74G | 256 | 79.9 | 666.669 qps |
| - distill | 9.47M | 1.74G | 256 | 81.1 | 666.669 qps |
| - deploy=True | 9.42M | 1.74G | 256 | 79.9 | 881.429 qps |
| FastViT_SA12 | 11.58M | 1.88G | 256 | 80.9 | 656.95 qps |
| - distill | 11.58M | 1.88G | 256 | 81.9 | 656.95 qps |
| - deploy=True | 11.54M | 1.88G | 256 | 80.9 | 833.011 qps |
| FastViT_SA24 | 21.55M | 3.66G | 256 | 82.7 | 371.84 qps |
| - distill | 21.55M | 3.66G | 256 | 83.4 | 371.84 qps |
| - deploy=True | 21.49M | 3.66G | 256 | 82.7 | 444.055 qps |
| FastViT_SA36 | 31.53M | 5.44G | 256 | 83.6 | 267.986 qps |
| - distill | 31.53M | 5.44G | 256 | 84.2 | 267.986 qps |
| - deploy=True | 31.44M | 5.43G | 256 | 83.6 | 325.967 qps |
| FastViT_MA36 | 44.07M | 7.64G | 256 | 83.9 | 211.928 qps |
| - distill | 44.07M | 7.64G | 256 | 84.6 | 211.928 qps |
| - deploy=True | 43.96M | 7.63G | 256 | 83.9 | 274.559 qps |
FBNetV3
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| FBNetV3B | 5.57M | 539.82M | 256 | 79.15 | 713.882 qps |
| FBNetV3D | 10.31M | 665.02M | 256 | 79.68 | 635.963 qps |
| FBNetV3G | 16.62M | 1379.30M | 256 | 82.05 | 478.835 qps |
FlexiViT
GCViT
GhostNet
GhostNetV2
GMLP
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| GMLPTiny16 | 6M | 1.35G | 224 | 72.3 | 234.187 qps |
| GMLPS16 | 20M | 4.44G | 224 | 79.6 | 138.363 qps |
| GMLPB16 | 73M | 15.82G | 224 | 81.6 | 77.816 qps |
GPViT
HaloNet
Hiera
HorNet
IFormer
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| IFormerSmall | 19.9M | 4.88G | 224 | 83.4 | 254.392 qps |
| - 384 | 20.9M | 16.29G | 384 | 84.6 | 128.98 qps |
| IFormerBase | 47.9M | 9.44G | 224 | 84.6 | 147.868 qps |
| - 384 | 48.9M | 30.86G | 384 | 85.7 | 77.8391 qps |
| IFormerLarge | 86.6M | 14.12G | 224 | 84.6 | 113.434 qps |
| - 384 | 87.7M | 45.74G | 384 | 85.8 | 60.0292 qps |
InceptionNeXt
LCNet
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| LCNet050 | 1.88M | 46.02M | 224 | 63.10 | 3107.89 qps |
| - ssld | 1.88M | 46.02M | 224 | 66.10 | 3107.89 qps |
| LCNet075 | 2.36M | 96.82M | 224 | 68.82 | 3083.55 qps |
| LCNet100 | 2.95M | 158.28M | 224 | 72.10 | 2752.6 qps |
| - ssld | 2.95M | 158.28M | 224 | 74.39 | 2752.6 qps |
| LCNet150 | 4.52M | 338.05M | 224 | 73.71 | 2250.69 qps |
| LCNet200 | 6.54M | 585.35M | 224 | 75.18 | 2028.31 qps |
| LCNet250 | 9.04M | 900.16M | 224 | 76.60 | 1686.7 qps |
| - ssld | 9.04M | 900.16M | 224 | 80.82 | 1686.7 qps |
LeViT
MaxViT
MetaTransFormer
MLP mixer
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| MLPMixerS32, JFT | 19.1M | 1.01G | 224 | 68.70 | 488.839 qps |
| MLPMixerS16, JFT | 18.5M | 3.79G | 224 | 73.83 | 451.962 qps |
| MLPMixerB32, JFT | 60.3M | 3.25G | 224 | 75.53 | 247.629 qps |
| - sam | 60.3M | 3.25G | 224 | 72.47 | 247.629 qps |
| MLPMixerB16 | 59.9M | 12.64G | 224 | 76.44 | 207.423 qps |
| - 21k_ft1k | 59.9M | 12.64G | 224 | 80.64 | 207.423 qps |
| - sam | 59.9M | 12.64G | 224 | 77.36 | 207.423 qps |
| - JFT | 59.9M | 12.64G | 224 | 80.00 | 207.423 qps |
| MLPMixerL32, JFT | 206.9M | 11.30G | 224 | 80.67 | 95.1865 qps |
| MLPMixerL16 | 208.2M | 44.66G | 224 | 71.76 | 77.9928 qps |
| - 21k_ft1k | 208.2M | 44.66G | 224 | 82.89 | 77.9928 qps |
| - JFT | 208.2M | 44.66G | 224 | 84.82 | 77.9928 qps |
| - 448 | 208.2M | 178.54G | 448 | 83.91 | |
| - 448, JFT | 208.2M | 178.54G | 448 | 86.78 | |
| MLPMixerH14, JFT | 432.3M | 121.22G | 224 | 86.32 | |
| - 448, JFT | 432.3M | 484.73G | 448 | 87.94 | |
MobileNetV3
MobileViT
MobileViT_V2
MogaNet
NAT
NFNets
PVT_V2
RegNetY
RegNetZ
RepViT
ResMLP
ResNeSt
ResNetD
ResNetQ
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| ResNet51Q | 35.7M | 4.87G | 224 | 82.36 | 838.754 qps |
| ResNet61Q | 36.8M | 5.96G | 224 | | 730.245 qps |
ResNeXt
SwinTransformerV2
TinyNet
TinyViT
UniFormer
VanillaNet
| Model | Params | FLOPs | Input | Top1 Acc | T4 Inference |
| ----- | ------ | ----- | ----- | -------- | ------------ |
| VanillaNet5 | 22.33M | 8.46G | 224 | 72.49 | 598.964 qps |
| - deploy=True | 15.52M | 5.17G | 224 | 72.49 | 798.199 qps |
| VanillaNet6 | 56.12M | 10.11G | 224 | 76.36 | 465.031 qps |
| - deploy=True | 32.51M | 6.00G | 224 | 76.36 | 655.944 qps |
| VanillaNet7 | 56.67M | 11.84G | 224 | 77.98 | 375.479 qps |
| - deploy=True | 32.80M | 6.90G | 224 | 77.98 | 527.723 qps |
| VanillaNet8 | 65.18M | 13.50G | 224 | 79.13 | 341.157 qps |
| - deploy=True | 37.10M | 7.75G | 224 | 79.13 | 479.328 qps |
| VanillaNet9 | 73.68M | 15.17G | 224 | 79.87 | 312.815 qps |
| - deploy=True | 41.40M | 8.59G | 224 | 79.87 | 443.464 qps |
| VanillaNet10 | 82.19M | 16.83G | 224 | 80.57 | 277.871 qps |
| - deploy=True | 45.69M | 9.43G | 224 | 80.57 | 408.082 qps |
| VanillaNet11 | 90.69M | 18.49G | 224 | 81.08 | 267.026 qps |
| - deploy=True | 50.00M | 10.27G | 224 | 81.08 | 377.239 qps |
| VanillaNet12 | 99.20M | 20.16G | 224 | 81.55 | 229.987 qps |
| - deploy=True | 54.29M | 11.11G | 224 | 81.55 | 358.076 qps |
| VanillaNet13 | 107.7M | 21.82G | 224 | 82.05 | 218.256 qps |
| - deploy=True | 58.59M | 11.96G | 224 | 82.05 | 334.244 qps |
VOLO
WaveMLP
Detection Models
EfficientDet
YOLO_NAS
YOLOR
| Model | Params | FLOPs | Input | COCO val AP | test AP | T4 Inference |
| ----- | ------ | ----- | ----- | ----------- | ------- | ------------ |
| YOLOR_CSP | 52.9M | 60.25G | 640 | 50.0 | 52.8 | 118.746 qps |
| YOLOR_CSPX | 99.8M | 111.11G | 640 | 51.5 | 54.8 | 67.9444 qps |
| YOLOR_P6 | 37.3M | 162.87G | 1280 | 52.5 | 55.7 | 49.3128 qps |
| YOLOR_W6 | 79.9M | 226.67G | 1280 | 53.6 ? | 56.9 | 40.2355 qps |
| YOLOR_E6 | 115.9M | 341.62G | 1280 | 50.3 ? | 57.6 | 21.5719 qps |
| YOLOR_D6 | 151.8M | 467.88G | 1280 | 50.8 ? | 58.2 | 16.6061 qps |
YOLOV7
YOLOV8
YOLOX
| Model | Params | FLOPs | Input | COCO val AP | test AP | T4 Inference |
| ----- | ------ | ----- | ----- | ----------- | ------- | ------------ |
| YOLOXNano | 0.91M | 0.53G | 416 | 25.8 | | 930.57 qps |
| YOLOXTiny | 5.06M | 3.22G | 416 | 32.8 | | 745.2 qps |
| YOLOXS | 9.0M | 13.39G | 640 | 40.5 | 40.5 | 380.38 qps |
| YOLOXM | 25.3M | 36.84G | 640 | 46.9 | 47.2 | 181.084 qps |
| YOLOXL | 54.2M | 77.76G | 640 | 49.7 | 50.1 | 111.517 qps |
| YOLOXX | 99.1M | 140.87G | 640 | 51.5 | 51.5 | 62.3189 qps |
Language Models
GPT2
LLaMA2
Stable Diffusion
| Model | Params | FLOPs | Input | Download |
| ----- | ------ | ----- | ----- | -------- |
| ViTTextLargePatch14 | 123.1M | 6.67G | [None, 77] | vit_text_large_patch14_clip.h5 |
| Encoder | 34.16M | 559.6G | [None, 512, 512, 3] | encoder_v1_5.h5 |
| UNet | 859.5M | 404.4G | [None, 64, 64, 4] | unet_v1_5.h5 |
| Decoder | 49.49M | 1259.5G | [None, 64, 64, 4] | decoder_v1_5.h5 |
Segmentation Models
YOLOV8 Segmentation
Segment Anything
Licenses
- This part is copied and modified from Github rwightman/pytorch-image-models.
- Code. The code here is licensed MIT. It is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue. So far all of the pretrained weights available here are pretrained on ImageNet and COCO with a select few that have some additional pretraining.
- ImageNet Pretrained Weights. ImageNet was released for non-commercial research purposes only (https://image-net.org/download). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.
- COCO Pretrained Weights. Should follow cocodataset termsofuse. The annotations in COCO dataset belong to the COCO Consortium and are licensed under a Creative Commons Attribution 4.0 License. The COCO Consortium does not own the copyright of the images. Use of the images must abide by the Flickr Terms of Use. The users of the images accept full responsibility for the use of the dataset, including but not limited to the use of any copies of copyrighted images that they may create from the dataset.
- Pretrained on more than ImageNet and COCO. Several weights included or references here were pretrained with proprietary datasets that I do not have access to. These include the Facebook WSL, SSL, SWSL ResNe(Xt) and the Google Noisy Student EfficientNet models. The Facebook models have an explicit non-commercial license (CC-BY-NC 4.0, https://github.com/facebookresearch/semi-supervised-ImageNet1K-models, https://github.com/facebookresearch/WSL-Images). The Google models do not appear to have any restriction beyond the Apache 2.0 license (and ImageNet concerns). In either case, you should contact Facebook or Google with any questions.
Citing
- BibTeX
@misc{leondgarse,
author = {Leondgarse},
title = {Keras CV Attention Models},
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
doi = {10.5281/zenodo.6506947},
howpublished = {\url{https://github.com/leondgarse/keras_cv_attention_models}}
}
- Latest DOI:
