gcvit

TensorFlow 2.0 implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf



GCViT: Global Context Vision Transformer

TensorFlow 2.0 Implementation of GCViT

This library implements GCViT in TensorFlow 2.0, specifically by subclassing tf.keras.Model, which gives it a PyTorch flavor.

Update

Paper Implementation & Explanation

I have explained the GCViT paper in the Kaggle notebook "GCViT: Global Context Vision Transformer", which also includes a detailed from-scratch implementation of the model. The notebook walks through each part of the model with intuition.

Do check it out, especially if you want to learn more about GCViT or implement it yourself. This notebook won the Kaggle ML Research Award 2022.

Model

  • Architecture: [figure]
  • Local vs. Global Attention: [figure]
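To make the local vs. global distinction concrete, here is a minimal NumPy sketch of the two attention types as the GCViT paper describes them: local attention takes queries, keys and values from the tokens inside each window, while global attention swaps the per-window queries for one set of global query tokens shared by every window. This is a single-head toy with random projections, not the library's code. The global query is simply assumed to be given (in GCViT it comes from a separate global token generator over the whole feature map), and all function names are hypothetical.

```python
import numpy as np

def window_partition(x, ws):
    """Split an (H, W, C) feature map into (num_windows, ws*ws, C) token groups."""
    h, w, c = x.shape
    x = x.reshape(h // ws, ws, w // ws, ws, c)
    x = x.transpose(0, 2, 1, 3, 4)  # (nH, nW, ws, ws, C)
    return x.reshape(-1, ws * ws, c)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    """Scaled dot-product attention over the last two axes."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

def local_attention(windows, wq, wk, wv):
    # Queries, keys and values all come from the tokens of the same window.
    return attention(windows @ wq, windows @ wk, windows @ wv)

def global_attention(windows, q_global, wk, wv):
    # One set of global query tokens (ws*ws, C) is broadcast to every window;
    # keys and values still come from each window's own tokens.
    q = np.broadcast_to(q_global, windows.shape)
    return attention(q, windows @ wk, windows @ wv)

rng = np.random.default_rng(0)
H = W = 14
C = 8
ws = 7
x = rng.standard_normal((H, W, C))
wq, wk, wv = (rng.standard_normal((C, C)) for _ in range(3))
windows = window_partition(x, ws)             # 4 windows of 49 tokens each
q_global = rng.standard_normal((ws * ws, C))  # hypothetical stand-in for the generated global query

local_out = local_attention(windows, wq, wk, wv)
global_out = global_attention(windows, q_global, wk, wv)
print(windows.shape, local_out.shape, global_out.shape)
```

Because the global query is shared, each window's tokens attend using information summarized from the whole image rather than only from the window itself, while the cost stays per-window.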

Result

The official codebase had an issue that was fixed on 12 August 2022. Here are the results of the ported weights on the ImageNetV2 test data:

Model          Acc@1   Acc@5   #Params
GCViT-XXTiny   0.663   0.873   12M
GCViT-XTiny    0.685   0.885   20M
GCViT-Tiny     0.708   0.899   28M
GCViT-Small    0.720   0.901   51M
GCViT-Base     0.731   0.907   90M
GCViT-Large    0.734   0.913   202M
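Acc@1 and Acc@5 are top-1 and top-5 accuracy: the fraction of test images whose true label is the single highest-scoring class, or is among the five highest-scoring classes. A minimal NumPy sketch of the metric follows; the function name and toy scores are illustrative and are not taken from the library or its evaluation script.

```python
import numpy as np

def top_k_accuracy(logits, labels, k):
    """Fraction of rows whose true label is among the k largest scores."""
    # argsort is ascending, so the last k columns hold the indices of the k largest scores
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return hits.mean()

# Toy example: 4 samples, 6 classes (no ties involving the true class)
logits = np.array([
    [0.1, 0.7, 0.05, 0.05, 0.05, 0.05],   # true 1: ranked 1st -> top-1 hit
    [0.3, 0.2, 0.25, 0.1, 0.1, 0.05],     # true 2: ranked 2nd -> top-5 hit only
    [0.01, 0.02, 0.1, 0.2, 0.3, 0.4],     # true 0: ranked last -> miss
    [0.5, 0.1, 0.1, 0.1, 0.1, 0.1],       # true 0: ranked 1st -> top-1 hit
])
labels = np.array([1, 2, 0, 0])
print(top_k_accuracy(logits, labels, 1), top_k_accuracy(logits, labels, 5))
```

Here two of four samples are top-1 hits and three of four are top-5 hits, which is why Acc@5 is always at least Acc@1 in the table above.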

Installation

pip install -U gcvit
# or
# pip install -U git+https://github.com/awsaf49/gcvit-tf

Usage

Load a pretrained model with the following code:

from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True)

For an input size other than 224x224:

from gcvit import GCViTTiny
model = GCViTTiny(input_shape=(512,512,3), pretrain=True, resize_query=True)
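One reason non-default input sizes need care is that every stage's feature-map size changes with the input. The sketch below assumes a stride-4 stem followed by 2x downsampling between each of four stages (an overall stride of 32, consistent with the 7x7 feature map shown further down for a 224x224 input) and an illustrative window size of 7; neither value is read from the library. The README does not spell out what resize_query does internally, so this only shows that a 512x512 input produces stage sizes that are no longer multiples of that window size.

```python
def stage_resolutions(input_size, stem_stride=4, num_stages=4):
    """Spatial size of the feature map at each stage (assumed architecture)."""
    sizes = []
    size = input_size // stem_stride
    for _ in range(num_stages):
        sizes.append(size)
        size //= 2  # assumed 2x downsampling between stages
    return sizes

WINDOW = 7  # assumed window size, for illustration only

for side in (224, 512):
    sizes = stage_resolutions(side)
    fits = [s % WINDOW == 0 for s in sizes]  # does each stage tile evenly into windows?
    print(side, sizes, fits)
```

Under these assumptions a 512x512 input ends in a 16x16 final feature map instead of 7x7.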

Simple code to check the model's prediction:

import tensorflow as tf
from skimage.data import chelsea
from gcvit import GCViTTiny

model = GCViTTiny(pretrain=True)
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch')  # Chelsea the cat
img = tf.image.resize(img, (224, 224))[None,]  # resize & create batch
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])  # top-5 (class id, name, score)
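For reference, mode='torch' in preprocess_input scales pixels to [0, 1] and then normalizes each channel with the standard ImageNet mean and standard deviation. The NumPy equivalent below is a sketch for understanding what the call does; the function name is hypothetical, and it assumes a channels-last RGB image.

```python
import numpy as np

# Standard ImageNet channel statistics used by the 'torch' preprocessing mode
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def torch_style_preprocess(img_uint8):
    """Mimic preprocess_input(..., mode='torch') for a channels-last RGB image."""
    x = img_uint8.astype(np.float32) / 255.0   # scale to [0, 1]
    return (x - IMAGENET_MEAN) / IMAGENET_STD  # per-channel normalization (broadcast over H, W)

img = np.full((2, 2, 3), 255, dtype=np.uint8)  # an all-white 2x2 image
out = torch_style_preprocess(img)
print(out[0, 0])
```

A white pixel maps to roughly (2.25, 2.43, 2.64), and a black pixel to -mean/std per channel, so pretrained weights see inputs centered near zero.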

Prediction:

[('n02124075', 'Egyptian_cat', 0.9194835),
('n02123045', 'tabby', 0.009686623), 
('n02123159', 'tiger_cat', 0.0061576385),
('n02127052', 'lynx', 0.0011503297), 
('n02883205', 'bow_tie', 0.00042479983)]

For feature extraction:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
model.reset_classifier(num_classes=0, head_act=None)
feature = model(img)
print(feature.shape)

Feature:

(None, 512)

For feature map:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
feature = model.forward_features(img)
print(feature.shape)

Feature map:

(None, 7, 7, 512)
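The two outputs are related. Assuming the classifier head applies global average pooling over the spatial grid (consistent with the (7, 7, 512) feature map collapsing to a 512-dimensional feature), and ignoring any normalization the library may apply between the two, the pooled feature is the spatial mean of the feature map and the classifier is a linear layer on top. The NumPy sketch uses random stand-in tensors and a hypothetical weight matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a forward_features output: (batch, height, width, channels)
feature_map = rng.standard_normal((1, 7, 7, 512)).astype(np.float32)

# Global average pooling over height and width -> (batch, channels)
pooled = feature_map.mean(axis=(1, 2))

# A hypothetical linear head mapping 512 features to 1000 ImageNet logits
W = rng.standard_normal((512, 1000)).astype(np.float32) * 0.01
b = np.zeros(1000, dtype=np.float32)
logits = pooled @ W + b

print(pooled.shape, logits.shape)  # (1, 512) (1, 1000)
```

Removing the head (num_classes=0) corresponds to stopping at pooled instead of computing logits.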

Kaggle Models

The pretrained models can also be loaded from Kaggle Models. Setting from_kaggle=True forces the model to load its weights from Kaggle Models without downloading them, so it can be used without internet access on Kaggle.

from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True, from_kaggle=True)

Live Demo

  • For a live demo of image classification and Grad-CAM with ImageNet weights, see the 🤗 Space powered by Gradio. Here's an example:

Example

For a working training example, check out these notebooks on Google Colab and Kaggle.

Here is the Grad-CAM result after training on the Flower Classification dataset:
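Grad-CAM weights each channel of a feature map (here, a final-stage map like the one returned by forward_features) by the spatially averaged gradient of the target class score with respect to that channel, sums the weighted channels, and keeps only the positive part. The NumPy sketch below implements that formula, taking the feature map and its gradient as given inputs; in practice they would come from an autodiff tool such as tf.GradientTape, which is not shown. The function name is hypothetical and this is not the demo's code.

```python
import numpy as np

def grad_cam(feature_map, grads):
    """Grad-CAM heatmap from an (H, W, C) feature map and d(class score)/d(feature_map)."""
    weights = grads.mean(axis=(0, 1))                              # alpha_k: mean gradient per channel
    cam = np.maximum((feature_map * weights).sum(axis=-1), 0.0)    # ReLU(sum_k alpha_k * A_k)
    if cam.max() > 0:
        cam = cam / cam.max()                                      # normalize to [0, 1] for display
    return cam

rng = np.random.default_rng(0)
A = rng.random((7, 7, 512)).astype(np.float32)              # stand-in feature map
dA = rng.standard_normal((7, 7, 512)).astype(np.float32)    # stand-in gradients
heatmap = grad_cam(A, dA)
print(heatmap.shape, float(heatmap.min()) >= 0.0, float(heatmap.max()) <= 1.0)
```

The resulting 7x7 map is then typically upsampled to the input resolution and overlaid on the image to produce visualizations like the one above.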

To Do

  • Convert it to multi-backend Keras 3.0
  • Segmentation Pipeline
  • Support for Kaggle Models
  • Remove tensorflow_addons
  • New updated weights have been added.
  • Working training example in Colab & Kaggle.
  • GradCAM showcase.
  • Gradio Demo.
  • Build model with tf.keras.Model.
  • Port weights from official repo.
  • Support for TPU.

Acknowledgement

Citation

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}
