- Currently recommended TF version is tensorflow==2.11.1, especially for training or TFLite conversion.
- The READMEs assume the following default imports, which are not repeated in each example.
import os
import sys
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow import keras
- Install as a pip package. kecam is a short alias name of this package. Note: the pip package kecam doesn't set any backend requirement; make sure either TensorFlow or PyTorch is installed beforehand. For PyTorch backend usage, refer to Keras PyTorch Backend.
pip install -U kecam
pip install -U keras-cv-attention-models
pip install -U git+https://github.com/leondgarse/keras_cv_attention_models
Refer to each sub-directory for detailed usage.
- Basic model prediction
from keras_cv_attention_models import volo
mm = volo.VOLO_d1(pretrained="imagenet")
""" Run predict """
import tensorflow as tf
from tensorflow import keras
from keras_cv_attention_models.test_images import cat
img = cat()
imm = keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
pred = tf.nn.softmax(pred).numpy()
print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
Or just use the model preset preprocess_input and decode_predictions:
from keras_cv_attention_models import coatnet
mm = coatnet.CoAtNet0()
from keras_cv_attention_models.test_images import cat
preds = mm(mm.preprocess_input(cat()))
print(mm.decode_predictions(preds))
The preset preprocess_input and decode_predictions are also compatible with the PyTorch backend.
os.environ['KECAM_BACKEND'] = 'torch'
from keras_cv_attention_models import caformer
mm = caformer.CAFormerS18()
from keras_cv_attention_models.test_images import cat
preds = mm(mm.preprocess_input(cat()))
print(preds.shape)
print(mm.decode_predictions(preds))
Set num_classes=0 to exclude the model top GlobalAveragePooling2D + Dense layers.
from keras_cv_attention_models import resnest
mm = resnest.ResNest50(num_classes=0)
print(mm.output_shape)
Setting num_classes={custom output classes} to anything other than 1000 or 0 will just skip loading the header Dense layer weights, as model.load_weights(weight_file, by_name=True, skip_mismatch=True) is used for loading weights.
from keras_cv_attention_models import swin_transformer_v2
mm = swin_transformer_v2.SwinTransformerV2Tiny_window8(num_classes=64)
- Reload your own model weights by setting pretrained="xxx.h5". This is better than calling model.load_weights directly when reloading a model with a different input_shape and mismatched weight shapes.
import os
from keras_cv_attention_models import coatnet
pretrained = os.path.expanduser('~/.keras/models/coatnet0_224_imagenet.h5')
mm = coatnet.CoAtNet1(input_shape=(384, 384, 3), pretrained=pretrained)
- The alias name kecam can be used instead of keras_cv_attention_models. Its __init__.py contains only from keras_cv_attention_models import *.
import kecam
mm = kecam.yolor.YOLOR_CSP()
imm = kecam.test_images.dog_cat()
preds = mm(mm.preprocess_input(imm))
bboxs, labels, confidences = mm.decode_predictions(preds)[0]
kecam.coco.show_image_with_bboxes(imm, bboxs, labels, confidences)
- The FLOPs calculation method is from TF 2.0 Feature: Flops calculation #32809. For the PyTorch backend, it needs thop: pip install thop.
from keras_cv_attention_models import coatnet, resnest, model_surgery
model_surgery.get_flops(coatnet.CoAtNet0())
model_surgery.get_flops(resnest.ResNest50())
- [Deprecated] tensorflow_addons is not imported by default. When reloading a model that depends on GroupNormalization, like MobileViTV2, directly from h5, tensorflow_addons needs to be imported manually first.
import tensorflow_addons as tfa
model_path = os.path.expanduser('~/.keras/models/mobilevit_v2_050_256_imagenet.h5')
mm = keras.models.load_model(model_path)
- Export TF model to ONNX. Needs tf2onnx for TF: pip install onnx tf2onnx onnxsim onnxruntime. For the PyTorch backend, ONNX export is supported by PyTorch itself.
from keras_cv_attention_models import volo, nat, model_surgery
mm = nat.DiNAT_Small(pretrained=True)
model_surgery.export_onnx(mm, fuse_conv_bn=True, batch_size=1, simplify=True)
from keras_cv_attention_models.imagenet import eval_func
aa = eval_func.ONNXModelInterf(mm.name + '.onnx')
inputs = np.random.uniform(size=[1, *mm.input_shape[1:]]).astype('float32')
print(f"{np.allclose(aa(inputs), mm(inputs), atol=1e-5) = }")
- Model summary: model_summary.csv contains gathered model info.
- params for model params count in M
- flops for FLOPs in G
- input for model input shape
- acc_metrics means Imagenet Top1 Accuracy for recognition models, COCO val AP for detection models
- inference_qps for T4 inference query per second with batch_size=1 + trtexec
- extra means if any extra training info.
from keras_cv_attention_models import plot_func
plot_series = [
"efficientnetv2", 'tinynet', 'lcnet', 'mobilenetv3', 'fasternet', 'fastervit', 'ghostnet',
'inceptionnext', 'efficientvit_b', 'mobilevit', 'convnextv2', 'efficientvit_m', 'hiera',
]
plot_func.plot_model_summary(
plot_series, model_table="model_summary.csv", log_scale_x=True, allow_extras=['mae_in1k_ft1k']
)
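The same CSV can also be inspected directly with pandas, for example to rank models by accuracy. A minimal sketch, assuming model_summary.csv with the columns described above is in the working directory:
import pandas as pd
# Load the gathered model info and keep only rows with a reported accuracy
dd = pd.read_csv("model_summary.csv")
dd = dd[dd["acc_metrics"].notna()]
# Show the best accuracy first; params is in M, flops is in G
print(dd.sort_values("acc_metrics", ascending=False).head())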

- Code formatting uses line-length=160:
find ./* -name "*.py" | grep -v __init__ | xargs -I {} black -l 160 {}
- Currently still under testing.
- COCO contains more detailed usage.
- custom_dataset_script.py can be used to create a json format file, which can then be used as --data_name xxx.json for training. Detailed usage can be found in Custom detection dataset.
- Default parameters for coco_train_script.py are EfficientDetD0 with input_shape=(256, 256, 3), batch_size=64, mosaic_mix_prob=0.5, freeze_backbone_epochs=32, total_epochs=105. Technically, any combination of pyramid structure backbone + EfficientDet / YOLOX header / YOLOR header + anchor_free / yolor / efficientdet anchors is supported.
- Currently 4 anchor types are supported. The parameter anchors_mode controls which anchors to use, with values in ["efficientdet", "anchor_free", "yolor", "yolov8"]. Default None uses the det_header preset. See the sketch after the table below.
- NOTE: YOLOV8 has a default regression_len=64 for bbox output length. Typically it's 4 for other detection models; for yolov8 it's reg_max=16 -> regression_len = 16 * 4 == 64.
| anchors_mode | use_object_scores | num_anchors | anchor_scale | aspect_ratios | num_scales | grid_zero_start |
| --- | --- | --- | --- | --- | --- | --- |
| efficientdet | False | 9 | 4 | [1, 2, 0.5] | 3 | False |
| anchor_free | True | 1 | 1 | [1] | 1 | True |
| yolor | True | 3 | None | presets | None | offset=0.5 |
| yolov8 | False | 1 | 1 | [1] | 1 | False |
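anchors_mode can also be passed when building a detection model directly in Python, not only through coco_train_script.py. A minimal sketch, assuming the det_header presets accept an anchors_mode argument as the anchors_mode description above suggests:
from keras_cv_attention_models import yolox
# Hypothetical example: build YOLOXS with efficientdet-style anchors instead of its default anchor_free mode
mm = yolox.YOLOXS(anchors_mode="efficientdet", pretrained=None)
print(mm.output_shape)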
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py -i 512 -p adamw --freeze_backbone_epochs 16 --lr_decay_steps 50
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone efficientnet.EfficientNetV2B0 --det_header efficientdet.EfficientDetD0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone resnest.ResNest50 --anchors_mode anchor_free
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone uniformer.UniformerSmall32 --anchors_mode yolor
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolox.YOLOXS --freeze_backbone_epochs 0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolox.YOLOXS --anchors_mode efficientdet --freeze_backbone_epochs 0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone coatnet.CoAtNet0 --det_header yolox.YOLOX --anchors_mode yolor
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolor.YOLOR_P6 --freeze_backbone_epochs 0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --det_header yolor.YOLOR_P6 --anchors_mode anchor_free --freeze_backbone_epochs 0
CUDA_VISIBLE_DEVICES='0' python3 coco_train_script.py --backbone convnext.ConvNeXtTiny --det_header yolor.YOLOR --anchors_mode yolor
Note: COCO training is still under testing; parameters and default behaviors may change. Take the risk if you would like to help developing.
- coco_eval_script.py is used for evaluating model AP / AR on the COCO validation set. It has a dependency pip install pycocotools which is not in the package requirements. More usage can be found in COCO Evaluation.
CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m efficientdet.EfficientDetD0 --resize_method bilinear --disable_antialias
CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m yolox.YOLOXTiny --use_bgr_input --nms_method hard --nms_iou_or_sigma 0.65
CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m yolor.YOLOR_CSP --nms_method hard --nms_iou_or_sigma 0.65 \
--nms_max_output_size 300 --nms_topk -1 --letterbox_pad 64 --input_shape 704
CUDA_VISIBLE_DEVICES='1' python3 coco_eval_script.py -m checkpoints/yoloxtiny_yolor_anchor.h5
- [Experimental] Training using the PyTorch backend
import os, sys, torch
os.environ["KECAM_BACKEND"] = "torch"
from keras_cv_attention_models.yolov8 import train, yolov8
from keras_cv_attention_models import efficientnet
global_device = torch.device("cuda:0") if torch.cuda.is_available() and int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")) >= 0 else torch.device("cpu")
bb = efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0)
model = yolov8.YOLOV8_N(backbone=bb, classifier_activation=None, pretrained=None).to(global_device)
ema = train.train(model, dataset_path="coco.json", initial_epoch=0)

- Currently TFLite does not support tf.image.extract_patches / tf.transpose with len(perm) > 4. Some operations may be supported in the latest or tf-nightly version; previously unsupported ops like gelu / Conv2D with groups>1 are working now. Try these if encountering issues.
- More discussion can be found in Converting a trained keras CV attention model to TFLite #17. Some speed testing results can be found in How to speed up inference on a quantized model #44.
- Functions like model_surgery.convert_groups_conv2d_2_split_conv2d and model_surgery.convert_gelu_to_approximate are not needed with an up-to-date TF version.
- Converting VOLO / HaloNet models is not supported, because they need a longer tf.transpose perm.
- model_surgery.convert_dense_to_conv converts all Dense layers with 3D / 4D inputs to Conv1D / Conv2D, as TFLite xnnpack currently doesn't support them.
from keras_cv_attention_models import beit, model_surgery, efficientformer, mobilevit
mm = efficientformer.EfficientFormerL1()
mm = model_surgery.convert_dense_to_conv(mm)
converter = tf.lite.TFLiteConverter.from_keras_model(mm)
open(mm.name + ".tflite", "wb").write(converter.convert())
| Model | Dense, use_xnnpack=false | Conv, use_xnnpack=false | Conv, use_xnnpack=true |
| --- | --- | --- | --- |
| MobileViT_S | Inference (avg) 215371 us | Inference (avg) 163836 us | Inference (avg) 163817 us |
| EfficientFormerL1 | Inference (avg) 126829 us | Inference (avg) 107053 us | Inference (avg) 107132 us |
- model_surgery.convert_extract_patches_to_conv converts
tf.image.extract_patches
to a Conv2D
version:
from keras_cv_attention_models import cotnet, model_surgery
from keras_cv_attention_models.imagenet import eval_func
mm = cotnet.CotNetSE50D()
mm = model_surgery.convert_groups_conv2d_2_split_conv2d(mm)
mm = model_surgery.convert_extract_patches_to_conv(mm)
converter = tf.lite.TFLiteConverter.from_keras_model(mm)
open(mm.name + ".tflite", "wb").write(converter.convert())
test_inputs = np.random.uniform(size=[1, *mm.input_shape[1:]])
print(np.allclose(mm(test_inputs), eval_func.TFLiteModelInterf(mm.name + '.tflite')(test_inputs), atol=1e-7))
- model_surgery.prepare_for_tflite is just a combination of the above functions:
from keras_cv_attention_models import beit, model_surgery
mm = beit.BeitBasePatch16()
mm = model_surgery.prepare_for_tflite(mm)
converter = tf.lite.TFLiteConverter.from_keras_model(mm)
open(mm.name + ".tflite", "wb").write(converter.convert())
- Detection models including efficientdet / yolox / yolor can be converted to TFLite format directly. If DecodePredictions also needs to be included in the TFLite model, set use_static_output=True for DecodePredictions, as TFLite requires a more static output shape. The model output shape will be fixed as [batch, max_output_size, 6]. The last dimension 6 means [bbox_top, bbox_left, bbox_bottom, bbox_right, label_index, confidence], and the valid detections are those where confidence > 0.
""" Init model """
from keras_cv_attention_models import efficientdet
model = efficientdet.EfficientDetD0(pretrained="coco")
""" Create a model with DecodePredictions using `use_static_output=True` """
model.decode_predictions.use_static_output = True
nn = model.decode_predictions(model.outputs[0], score_threshold=0.5)
bb = keras.models.Model(model.inputs[0], nn)
""" Convert TFLite """
converter = tf.lite.TFLiteConverter.from_keras_model(bb)
open(bb.name + ".tflite", "wb").write(converter.convert())
""" Inference test """
from keras_cv_attention_models.imagenet import eval_func
from keras_cv_attention_models import test_images
dd = eval_func.TFLiteModelInterf(bb.name + ".tflite")
imm = test_images.cat()
inputs = tf.expand_dims(tf.image.resize(imm, dd.input_shape[1:-1]), 0)
inputs = keras.applications.imagenet_utils.preprocess_input(inputs, mode='torch')
preds = dd(inputs)[0]
print(f"{preds.shape = }")
pred = preds[preds[:, -1] > 0]
bboxes, labels, confidences = pred[:, :4], pred[:, 4], pred[:, -1]
print(f"{bboxes = }, {labels = }, {confidences = }")
""" Show result """
from keras_cv_attention_models.coco import data
data.show_image_with_bboxes(imm, bboxes, labels, confidences, num_classes=90)
- Experimental Keras PyTorch Backend.
- Set the OS environment variable export KECAM_BACKEND='torch' to enable this PyTorch backend.
- Currently supports most recognition and detection models except hornet*gf / nfnets / volo. Detection models use torchvision.ops.nms while running prediction.
- Basic model build and prediction.
- Will load the same h5 weights as the TF version if available.
- Note: input_shape will auto-fit the image data format. Given input_shape=(224, 224, 3) or input_shape=(3, 224, 224), both will be set to (3, 224, 224) if channels_first.
- Note: the model is set to eval mode by default.
os.environ['KECAM_BACKEND'] = 'torch'
from keras_cv_attention_models import res_mlp
mm = res_mlp.ResMLP12()
print(f"{mm.input_shape = }")
import torch
print(f"{isinstance(mm, torch.nn.Module) = }")
from keras_cv_attention_models.test_images import cat
print(mm.decode_predictions(mm(mm.preprocess_input(cat())))[0])
- Export typical PyTorch onnx / pth.
import torch
torch.onnx.export(mm, torch.randn(1, 3, *mm.input_shape[2:]), mm.name + ".onnx")
mm.export_onnx()
mm.export_pth()
- Save weights as h5. This h5 can also be loaded in a typical TF backend model. Currently only weights are supported, not the model structure.
mm.save_weights("foo.h5")
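The saved h5 can then be reloaded on the TF backend through the usual pretrained argument. A minimal sketch, assuming a new process where KECAM_BACKEND is not set and foo.h5 from the step above is in the working directory:
# In a new process, the default TF backend is used when KECAM_BACKEND is not set
from keras_cv_attention_models import res_mlp
mm = res_mlp.ResMLP12(pretrained="foo.h5")  # reload the weights saved from the torch backend model
print(mm.output_shape)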
- Training with compile and fit. Note: loss function arguments should be y_true, y_pred, while typical torch loss functions use y_pred, y_true.
import torch
from keras_cv_attention_models.backend import models, layers
mm = models.Sequential([layers.Input([3, 32, 32]), layers.Conv2D(32, 3), layers.GlobalAveragePooling2D(), layers.Dense(10)])
if torch.cuda.is_available():
_ = mm.to("cuda")
xx = torch.rand([64, *mm.input_shape[1:]])
yy = torch.functional.F.one_hot(torch.randint(0, mm.output_shape[-1], size=[64]), mm.output_shape[-1]).float()
loss = lambda y_true, y_pred: (y_true - y_pred.float()).abs().mean()
mm.train_compile(optimizer="AdamW", loss=loss, metrics='acc', grad_accumulate=4)
mm.fit(xx, yy, epochs=2, batch_size=4)