A direct translation of the official PyTorch implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" to TensorFlow 2.
The official PyTorch implementation can be found here.
This fork is a packaged version of the original repo.
```bash
pip install git+https://github.com/johnypark/Swin-Transformer-Tensorflow@main
```
Still under construction:
- Update README
- Make things work
Swin Transformer (the name Swin stands for Shifted windows) was first described on arXiv and serves capably as a
general-purpose backbone for computer vision. It is a hierarchical Transformer whose representation is
computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention
computation to non-overlapping local windows while still allowing cross-window connections.
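The non-overlapping window partition described above can be sketched in a few lines. This is a simplified, unbatched NumPy illustration of the idea, not the repo's actual TensorFlow implementation:

```python
import numpy as np

def window_partition(x, window_size):
    """Split a (H, W, C) feature map into non-overlapping square windows.

    Returns an array of shape (num_windows, window_size, window_size, C),
    so self-attention can be computed independently within each window.
    """
    H, W, C = x.shape
    # Group rows and columns into window-sized blocks...
    x = x.reshape(H // window_size, window_size, W // window_size, window_size, C)
    # ...then bring the two block axes together and flatten them into one.
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, window_size, window_size, C)

# An 8x8 single-channel map with window size 4 yields 4 windows of 4x4.
feat = np.arange(8 * 8).reshape(8, 8, 1)
windows = window_partition(feat, 4)
```

The shifted variant in the paper additionally cycles the feature map (e.g. with a roll by `window_size // 2`) before partitioning, so that adjacent windows exchange information across layers.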
Swin Transformer achieves strong performance on COCO object detection (58.7 box AP and 51.1 mask AP on test-dev) and
ADE20K semantic segmentation (53.5 mIoU on val), surpassing previous models by a large margin.
Swin-T:
```bash
python main.py --cfg configs/swin_tiny_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k
```
Swin-S:
```bash
python main.py --cfg configs/swin_small_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k
```
Swin-B:
```bash
python main.py --cfg configs/swin_base_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k
```
The possible options for `cfg` and `weights_type` are:
| cfg | weights_type | 22K model | 1K Model |
|---|---|---|---|
| configs/swin_tiny_patch4_window7_224.yaml | imagenet_1k | - | github |
| configs/swin_small_patch4_window7_224.yaml | imagenet_1k | - | github |
| configs/swin_base_patch4_window7_224.yaml | imagenet_1k | - | github |
| configs/swin_base_patch4_window12_384.yaml | imagenet_1k | - | github |
| configs/swin_base_patch4_window7_224.yaml | imagenet_22kto1k | - | github |
| configs/swin_base_patch4_window12_384.yaml | imagenet_22kto1k | - | github |
| configs/swin_large_patch4_window7_224.yaml | imagenet_22kto1k | - | github |
| configs/swin_large_patch4_window12_384.yaml | imagenet_22kto1k | - | github |
| configs/swin_base_patch4_window7_224.yaml | imagenet_22k | github | - |
| configs/swin_base_patch4_window12_384.yaml | imagenet_22k | github | - |
| configs/swin_large_patch4_window7_224.yaml | imagenet_22k | github | - |
| configs/swin_large_patch4_window12_384.yaml | imagenet_22k | github | - |
To create a custom classification model:

```python
import argparse

import tensorflow as tf

from config import get_config
from models.build import build_model

parser = argparse.ArgumentParser('Custom Swin Transformer')

parser.add_argument(
    '--cfg',
    type=str,
    metavar="FILE",
    help='path to config file',
    default="CUSTOM_YAML_FILE_PATH"
)
parser.add_argument(
    '--resume',
    type=int,
    help='Whether or not to resume training from pretrained weights',
    choices={0, 1},
    default=1,
)
parser.add_argument(
    '--weights_type',
    type=str,
    help='Type of pretrained weight file to load including number of classes',
    choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
    default="imagenet_1k",
)
args = parser.parse_args()

custom_config = get_config(args, include_top=False)

swin_transformer = tf.keras.Sequential([
    build_model(config=custom_config, load_pretrained=args.resume, weights_type=args.weights_type),
    tf.keras.layers.Dense(CUSTOM_NUM_CLASSES)
])
```

Model outputs are logits, so don't forget to apply softmax during training/inference!
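For instance, converting a vector of logits into class probabilities looks like this (NumPy is used here so the snippet is self-contained; inside a TensorFlow pipeline you would typically use `tf.nn.softmax` or a softmax-from-logits loss instead):

```python
import numpy as np

# Hypothetical logits for a 3-class head.
logits = np.array([2.0, 1.0, 0.1])

# Numerically stable softmax: shift by the max before exponentiating.
probs = np.exp(logits - logits.max())
probs /= probs.sum()
```

Equivalently, passing `from_logits=True` to a Keras loss such as `tf.keras.losses.SparseCategoricalCrossentropy` lets the loss handle the softmax internally during training.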
You can easily customize the model configs with custom YAML files. Predefined YAML files provided by Microsoft are located in the configs directory.
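For reference, a minimal custom YAML might override a few of the Swin-specific fields. The key names below follow the predefined Microsoft configs; the values are purely illustrative:

```yaml
MODEL:
  TYPE: swin
  NAME: my_custom_swin
  SWIN:
    EMBED_DIM: 96
    DEPTHS: [ 2, 2, 6, 2 ]
    NUM_HEADS: [ 3, 6, 12, 24 ]
    WINDOW_SIZE: 7
DATA:
  IMG_SIZE: 224
```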
We provide a Python script that converts official PyTorch weights into TensorFlow checkpoints:

```bash
$ python convert_weights.py --cfg config_file --weights the_path_to_pytorch_weights --weights_type type_of_pretrained_weights --output the_path_to_output_tf_weights
```

- Translate model code over to TensorFlow
- Load PyTorch pretrained weights into TensorFlow model
- Write trainer code
- Reproduce results presented in paper
- Object Detection
- Reproduce training efficiency of official code in TensorFlow
```bibtex
@misc{liu2021swin,
    title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
    author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo},
    year={2021},
    eprint={2103.14030},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```