From: @ghostnet Reviewed-by: Signed-off-by:pull/14128/MERGE
| @@ -0,0 +1,75 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| eval. | |||
| """ | |||
| import os | |||
| import argparse | |||
| from mindspore import context | |||
| from mindspore import nn | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.common import dtype as mstype | |||
| from src.pet_dataset import create_dataset | |||
| from src.config import config_ascend, config_gpu | |||
| from src.tnt import tnt_b | |||
| parser = argparse.ArgumentParser(description='Image classification') | |||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | |||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | |||
| parser.add_argument('--platform', type=str, default=None, help='run platform') | |||
| args_opt = parser.parse_args() | |||
| if __name__ == '__main__': | |||
| config_platform = None | |||
| if args_opt.platform == "Ascend": | |||
| config_platform = config_ascend | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", | |||
| device_id=device_id, save_graphs=False) | |||
| elif args_opt.platform == "GPU": | |||
| config_platform = config_gpu | |||
| context.set_context(mode=context.PYNATIVE_MODE, | |||
| device_target="GPU", save_graphs=False) | |||
| else: | |||
| raise ValueError("Unsupported platform.") | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||
| net = tnt_b(num_class=config_platform.num_classes) | |||
| if args_opt.checkpoint_path: | |||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| if args_opt.platform == "Ascend": | |||
| net.to_float(mstype.float16) | |||
| for _, cell in net.cells_and_names(): | |||
| if isinstance(cell, nn.Dense): | |||
| cell.to_float(mstype.float32) | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | |||
| do_train=False, | |||
| config=config_platform, | |||
| platform=args_opt.platform, | |||
| batch_size=config_platform.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| model = Model(net, loss_fn=loss, metrics={'acc'}) | |||
| res = model.eval(dataset) | |||
| print("result:", res, "ckpt=", args_opt.checkpoint_path) | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """hub config.""" | |||
| from src.tnt import tnt_b | |||
| def create_network(name, *args, **kwargs): | |||
| if name == 'TNT-B': | |||
| return tnt_b(*args, **kwargs) | |||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||
| @@ -0,0 +1,128 @@ | |||
| # Contents | |||
| - [TNT Description](#tnt-description) | |||
| - [Model Architecture](#model-architecture) | |||
| - [Dataset](#dataset) | |||
| - [Environment Requirements](#environment-requirements) | |||
| - [Script Description](#script-description) | |||
| - [Script and Sample Code](#script-and-sample-code) | |||
| - [Training Process](#training-process) | |||
| - [Evaluation Process](#evaluation-process) | |||
| - [Evaluation](#evaluation) | |||
| - [Model Description](#model-description) | |||
| - [Performance](#performance) | |||
| - [Training Performance](#evaluation-performance) | |||
| - [Inference Performance](#evaluation-performance) | |||
| - [Description of Random Situation](#description-of-random-situation) | |||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||
| ## [TNT Description](#contents) | |||
| The TNT (Transformer in Transformer) network is a pure transformer model for visual recognition. TNT treats an image as a sequence of patches and treats a patch as a sequence of pixels. TNT block utilizes a outer transformer block to process the sequence of patches and an inner transformer block to process the sequence of pixels. | |||
| [Paper](https://arxiv.org/abs/2103.00112): Kai Han, An Xiao, Enhua Wu, Jianyuan Guo, Chunjing Xu, Yunhe Wang. Transformer in Transformer. preprint 2021. | |||
| ## [Model architecture](#contents) | |||
| The overall network architecture of TNT is show below: | |||
|  | |||
| ## [Dataset](#contents) | |||
| Dataset used: [Oxford-IIIT Pet](https://www.robots.ox.ac.uk/~vgg/data/pets/) | |||
| - Dataset size: 7049 colorful images in 1000 classes | |||
| - Train: 3680 images | |||
| - Test: 3369 images | |||
| - Data format: RGB images. | |||
| - Note: Data will be processed in src/dataset.py | |||
| ## [Environment Requirements](#contents) | |||
| - Hardware(Ascend/GPU) | |||
| - Prepare hardware environment with Ascend or GPU. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. | |||
| - Framework | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - For more information, please check the resources below£º | |||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
| ## [Script description](#contents) | |||
| ### [Script and sample code](#contents) | |||
| ```python | |||
| TNT | |||
| ├── eval.py # inference entry | |||
| ├── fig | |||
| │ └── tnt.png # the illustration of TNT network | |||
| ├── readme.md # Readme | |||
| └── src | |||
| ├── config.py # config of model and data | |||
| ├── pet_dataset.py # dataset loader | |||
| └── tnt.py # TNT network | |||
| ``` | |||
| ## [Training process](#contents) | |||
| To Be Done | |||
| ## [Eval process](#contents) | |||
| ### Usage | |||
| After installing MindSpore via the official website, you can start evaluation as follows: | |||
| ### Launch | |||
| ```bash | |||
| # infer example | |||
| GPU: python eval.py --model tnt-b --dataset_path ~/Pets/test.mindrecord --platform GPU --checkpoint_path [CHECKPOINT_PATH] | |||
| ``` | |||
| > checkpoint can be downloaded at https://www.mindspore.cn/resources/hub. | |||
| ### Result | |||
| ```bash | |||
| result: {'acc': 0.95} ckpt= ./tnt-b-pets.ckpt | |||
| ``` | |||
| ## [Model Description](#contents) | |||
| ### [Performance](#contents) | |||
| #### Evaluation Performance | |||
| ##### TNT on ImageNet2012 | |||
| | Parameters | | | | |||
| | -------------------------- | -------------------------------------- |---------------------------------- | | |||
| | Model Version | TNT-B |TNT-S| | |||
| | uploaded Date | 21/03/2021 (month/day/year) | 21/03/2021 (month/day/year) | | |||
| | MindSpore Version | 1.1 | 1.1 | | |||
| | Dataset | ImageNet2012 | ImageNet2012| | |||
| | Input size | 224x224 | 224x224| | |||
| | Parameters (M) | 86.4 | 23.8 | | |||
| | FLOPs (M) | 14.1 | 5.2 | | |||
| | Accuracy (Top1) | 82.8 | 81.3 | | |||
| ###### TNT on Oxford-IIIT Pet | |||
| | Parameters | | | | |||
| | -------------------------- | -------------------------------------- |---------------------------------- | | |||
| | Model Version | TNT-B |TNT-S| | |||
| | uploaded Date | 21/03/2021 (month/day/year) | 21/03/2021 (month/day/year) | | |||
| | MindSpore Version | 1.1 | 1.1 | | |||
| | Dataset | Oxford-IIIT Pet | Oxford-IIIT Pet| | |||
| | Input size | 384x384 | 384x384| | |||
| | Parameters (M) | 86.4 | 23.8 | | |||
| | Accuracy (Top1) | 95.0 | 94.7 | | |||
| ## [Description of Random Situation](#contents) | |||
| In dataset.py, we set the seed inside "create_dataset" function. We also use random seed in train.py. | |||
| ## [ModelZoo Homepage](#contents) | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| @@ -0,0 +1,54 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| network config setting, will be used in train.py and eval.py | |||
| """ | |||
| from easydict import EasyDict as ed | |||
| config_ascend = ed({ | |||
| "num_classes": 37, | |||
| "image_height": 384, | |||
| "image_width": 384, | |||
| "batch_size": 50, | |||
| "epoch_size": 300, | |||
| "warmup_epochs": 5, | |||
| "lr": 1e-3, | |||
| "momentum": 0.9, | |||
| "weight_decay": 0.05, | |||
| "label_smooth": 0.1, | |||
| "loss_scale": 1024, | |||
| "save_checkpoint": True, | |||
| "save_checkpoint_epochs": 1, | |||
| "keep_checkpoint_max": 200, | |||
| "save_checkpoint_path": "./checkpoint", | |||
| }) | |||
| config_gpu = ed({ | |||
| "num_classes": 37, | |||
| "image_height": 384, | |||
| "image_width": 384, | |||
| "batch_size": 50, | |||
| "epoch_size": 300, | |||
| "warmup_epochs": 5, | |||
| "lr": 1e-3, | |||
| "momentum": 0.9, | |||
| "weight_decay": 0.05, | |||
| "label_smooth": 0.1, | |||
| "loss_scale": 1024, | |||
| "save_checkpoint": True, | |||
| "save_checkpoint_epochs": 1, | |||
| "keep_checkpoint_max": 500, | |||
| "save_checkpoint_path": "./checkpoint", | |||
| }) | |||
| @@ -0,0 +1,97 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| create train or eval dataset. | |||
| """ | |||
| import os | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset.engine as de | |||
| import mindspore.dataset.transforms.py_transforms as py_transforms | |||
| import mindspore.dataset.transforms.c_transforms as c_transforms | |||
| import mindspore.dataset.vision.py_transforms as py_vision | |||
| from mindspore.dataset.vision import Inter | |||
| def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=1): | |||
| """ | |||
| create a train or eval dataset | |||
| Args: | |||
| dataset_path(string): the path of dataset. | |||
| do_train(bool): whether dataset is used for train or eval. | |||
| repeat_num(int): the repeat times of dataset. Default: 1 | |||
| batch_size(int): the batch size of dataset. Default: 32 | |||
| Returns: | |||
| dataset | |||
| """ | |||
| if platform == "Ascend": | |||
| rank_size = int(os.getenv("RANK_SIZE")) | |||
| rank_id = int(os.getenv("RANK_ID")) | |||
| if rank_size == 1: | |||
| ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True) | |||
| else: | |||
| ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=rank_size, shard_id=rank_id) | |||
| elif platform == "GPU": | |||
| if do_train: | |||
| from mindspore.communication.management import get_rank, get_group_size | |||
| ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=get_group_size(), shard_id=get_rank()) | |||
| else: | |||
| ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=False) | |||
| else: | |||
| raise ValueError("Unsupported platform.") | |||
| resize_height = config.image_height | |||
| resize_width = config.image_width | |||
| buffer_size = 1000 | |||
| # define map operations | |||
| random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(resize_height, resize_width), | |||
| scale=(0.08, 1.0), ratio=(3./4., 4./3.), | |||
| interpolation=Inter.BICUBIC) | |||
| random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5) | |||
| color_jitter = 0.4 | |||
| adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter) | |||
| random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range, | |||
| contrast=adjust_range, | |||
| saturation=adjust_range) | |||
| decode_p = py_vision.Decode() | |||
| resize_p = py_vision.Resize(int(resize_height), interpolation=Inter.BICUBIC) | |||
| center_crop_p = py_vision.CenterCrop(resize_height) | |||
| totensor = py_vision.ToTensor() | |||
| normalize_p = py_vision.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | |||
| if do_train: | |||
| trans = py_transforms.Compose([decode_p, random_resize_crop_bicubic, random_horizontal_flip_op, | |||
| random_color_jitter_op, totensor, normalize_p]) | |||
| else: | |||
| trans = py_transforms.Compose([decode_p, resize_p, center_crop_p, totensor, normalize_p]) | |||
| type_cast_op = c_transforms.TypeCast(mstype.int32) | |||
| ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8) | |||
| ds = ds.map(input_columns="label_list", operations=type_cast_op, num_parallel_workers=8) | |||
| # apply shuffle operations | |||
| ds = ds.shuffle(buffer_size=buffer_size) | |||
| # apply batch operations | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| # apply dataset repeat operation | |||
| ds = ds.repeat(repeat_num) | |||
| return ds | |||
| @@ -0,0 +1,390 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """TNT""" | |||
| import math | |||
| import copy | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| class MLP(nn.Cell): | |||
| """MLP""" | |||
| def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.): | |||
| super(MLP, self).__init__() | |||
| out_features = out_features or in_features | |||
| hidden_features = hidden_features or in_features | |||
| self.fc1 = nn.Dense(in_features, hidden_features) | |||
| self.dropout = nn.Dropout(1. - dropout) | |||
| self.fc2 = nn.Dense(hidden_features, out_features) | |||
| self.act = nn.GELU() | |||
| def construct(self, x): | |||
| x = self.fc1(x) | |||
| x = self.act(x) | |||
| x = self.dropout(x) | |||
| x = self.fc2(x) | |||
| x = self.dropout(x) | |||
| return x | |||
| class Attention(nn.Cell): | |||
| """Multi-head Attention""" | |||
| def __init__(self, dim, hidden_dim=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): | |||
| super(Attention, self).__init__() | |||
| hidden_dim = hidden_dim or dim | |||
| self.hidden_dim = hidden_dim | |||
| self.num_heads = num_heads | |||
| head_dim = hidden_dim // num_heads | |||
| self.head_dim = head_dim | |||
| self.scale = head_dim ** -0.5 | |||
| self.qk = nn.Dense(dim, hidden_dim * 2, has_bias=qkv_bias) | |||
| self.v = nn.Dense(dim, hidden_dim, has_bias=qkv_bias) | |||
| self.softmax = nn.Softmax(axis=-1) | |||
| self.batmatmul_trans_b = P.BatchMatMul(transpose_b=True) | |||
| self.attn_drop = nn.Dropout(1. - attn_drop) | |||
| self.batmatmul = P.BatchMatMul() | |||
| self.proj = nn.Dense(hidden_dim, dim) | |||
| self.proj_drop = nn.Dropout(1. - proj_drop) | |||
| self.transpose = P.Transpose() | |||
| self.reshape = P.Reshape() | |||
| def construct(self, x): | |||
| """Multi-head Attention""" | |||
| B, N, _ = x.shape | |||
| qk = self.transpose(self.reshape(self.qk(x), (B, N, 2, self.num_heads, self.head_dim)), (2, 0, 3, 1, 4)) | |||
| q, k = qk[0], qk[1] | |||
| v = self.transpose(self.reshape(self.v(x), (B, N, self.num_heads, self.head_dim)), (0, 2, 1, 3)) | |||
| attn = self.softmax(self.batmatmul_trans_b(q, k) * self.scale) | |||
| attn = self.attn_drop(attn) | |||
| x = self.reshape(self.transpose(self.batmatmul(attn, v), (0, 2, 1, 3)), (B, N, -1)) | |||
| x = self.proj(x) | |||
| x = self.proj_drop(x) | |||
| return x | |||
| class DropConnect(nn.Cell): | |||
| """drop connect implementation""" | |||
| def __init__(self, drop_connect_rate=0., seed0=0, seed1=0): | |||
| super(DropConnect, self).__init__() | |||
| self.shape = P.Shape() | |||
| self.dtype = P.DType() | |||
| self.keep_prob = 1 - drop_connect_rate | |||
| self.dropout = P.Dropout(keep_prob=self.keep_prob) | |||
| self.keep_prob_tensor = Tensor(self.keep_prob, dtype=mstype.float32) | |||
| def construct(self, x): | |||
| shape = self.shape(x) | |||
| dtype = self.dtype(x) | |||
| ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1) | |||
| _, mask = self.dropout(ones_tensor) | |||
| x = x * mask | |||
| x = x / self.keep_prob_tensor | |||
| return x | |||
| class Pixel2Patch(nn.Cell): | |||
| """Projecting Pixel Embedding to Patch Embedding""" | |||
| def __init__(self, outer_dim): | |||
| super(Pixel2Patch, self).__init__() | |||
| self.norm_proj = nn.LayerNorm([outer_dim]) | |||
| self.proj = nn.Dense(outer_dim, outer_dim) | |||
| self.fake = Parameter(Tensor(np.zeros((1, 1, outer_dim)), | |||
| mstype.float32), name='fake', requires_grad=False) | |||
| self.reshape = P.Reshape() | |||
| self.tile = P.Tile() | |||
| self.concat = P.Concat(axis=1) | |||
| def construct(self, pixel_embed, patch_embed): | |||
| B, N, _ = patch_embed.shape | |||
| proj = self.reshape(pixel_embed, (B, N - 1, -1)) | |||
| proj = self.proj(self.norm_proj(proj)) | |||
| proj = self.concat((self.tile(self.fake, (B, 1, 1)), proj)) | |||
| patch_embed = patch_embed + proj | |||
| return patch_embed | |||
| class TNTBlock(nn.Cell): | |||
| """TNT Block""" | |||
| def __init__(self, inner_config, outer_config, dropout=0., attn_dropout=0., drop_connect=0.): | |||
| super().__init__() | |||
| # inner transformer | |||
| inner_dim = inner_config['dim'] | |||
| num_heads = inner_config['num_heads'] | |||
| mlp_ratio = inner_config['mlp_ratio'] | |||
| self.inner_norm1 = nn.LayerNorm([inner_dim]) | |||
| self.inner_attn = Attention(inner_dim, num_heads=num_heads, qkv_bias=True, attn_drop=attn_dropout, | |||
| proj_drop=dropout) | |||
| self.inner_norm2 = nn.LayerNorm([inner_dim]) | |||
| self.inner_mlp = MLP(inner_dim, int(inner_dim * mlp_ratio), dropout=dropout) | |||
| # outer transformer | |||
| outer_dim = outer_config['dim'] | |||
| num_heads = outer_config['num_heads'] | |||
| mlp_ratio = outer_config['mlp_ratio'] | |||
| self.outer_norm1 = nn.LayerNorm([outer_dim]) | |||
| self.outer_attn = Attention(outer_dim, num_heads=num_heads, qkv_bias=True, attn_drop=attn_dropout, | |||
| proj_drop=dropout) | |||
| self.outer_norm2 = nn.LayerNorm([outer_dim]) | |||
| self.outer_mlp = MLP(outer_dim, int(outer_dim * mlp_ratio), dropout=dropout) | |||
| # pixel2patch | |||
| self.pixel2patch = Pixel2Patch(outer_dim) | |||
| # assistant | |||
| self.drop_connect = DropConnect(drop_connect) | |||
| self.reshape = P.Reshape() | |||
| self.tile = P.Tile() | |||
| self.concat = P.Concat(axis=1) | |||
| def construct(self, pixel_embed, patch_embed): | |||
| """TNT Block""" | |||
| pixel_embed = pixel_embed + self.inner_attn(self.inner_norm1(pixel_embed)) | |||
| pixel_embed = pixel_embed + self.inner_mlp(self.inner_norm2(pixel_embed)) | |||
| patch_embed = self.pixel2patch(pixel_embed, patch_embed) | |||
| patch_embed = patch_embed + self.outer_attn(self.outer_norm1(patch_embed)) | |||
| patch_embed = patch_embed + self.outer_mlp(self.outer_norm2(patch_embed)) | |||
| return pixel_embed, patch_embed | |||
| def _get_clones(module, N): | |||
| """get_clones""" | |||
| return nn.CellList([copy.deepcopy(module) for i in range(N)]) | |||
| class TNTEncoder(nn.Cell): | |||
| """TNT""" | |||
| def __init__(self, encoder_layer, num_layers): | |||
| super().__init__() | |||
| self.layers = _get_clones(encoder_layer, num_layers) | |||
| self.num_layers = num_layers | |||
| def construct(self, pixel_embed, patch_embed): | |||
| """TNT""" | |||
| for layer in self.layers: | |||
| pixel_embed, patch_embed = layer(pixel_embed, patch_embed) | |||
| return pixel_embed, patch_embed | |||
| class _stride_unfold_(nn.Cell): | |||
| """Unfold with stride""" | |||
| def __init__( | |||
| self, kernel_size, stride=-1): | |||
| super(_stride_unfold_, self).__init__() | |||
| if stride == -1: | |||
| self.stride = kernel_size | |||
| else: | |||
| self.stride = stride | |||
| self.kernel_size = kernel_size | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| self.unfold = _unfold_(kernel_size) | |||
| def construct(self, x): | |||
| """TNT""" | |||
| N, C, H, W = x.shape | |||
| leftup_idx_x = [] | |||
| leftup_idx_y = [] | |||
| nh = int((H - self.kernel_size) / self.stride + 1) | |||
| nw = int((W - self.kernel_size) / self.stride + 1) | |||
| for i in range(nh): | |||
| leftup_idx_x.append(i * self.stride) | |||
| for i in range(nw): | |||
| leftup_idx_y.append(i * self.stride) | |||
| NumBlock_x = len(leftup_idx_x) | |||
| NumBlock_y = len(leftup_idx_y) | |||
| zeroslike = P.ZerosLike() | |||
| cc_2 = P.Concat(axis=2) | |||
| cc_3 = P.Concat(axis=3) | |||
| unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size, | |||
| NumBlock_y * self.kernel_size), mstype.float32) | |||
| N, C, H, W = unf_x.shape | |||
| for i in range(NumBlock_x): | |||
| for j in range(NumBlock_y): | |||
| unf_i = i * self.kernel_size | |||
| unf_j = j * self.kernel_size | |||
| org_i = leftup_idx_x[i] | |||
| org_j = leftup_idx_y[j] | |||
| fill = x[:, :, org_i:org_i + self.kernel_size, | |||
| org_j:org_j + self.kernel_size] | |||
| unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), | |||
| cc_2((cc_2((zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fill)), | |||
| zeroslike(unf_x[:, :, unf_i + self.kernel_size:, | |||
| unf_j:unf_j + self.kernel_size]))))), | |||
| zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:]))) | |||
| y = self.unfold(unf_x) | |||
| return y | |||
| class _unfold_(nn.Cell): | |||
| """Unfold""" | |||
| def __init__( | |||
| self, kernel_size, stride=-1): | |||
| super(_unfold_, self).__init__() | |||
| if stride == -1: | |||
| self.stride = kernel_size | |||
| self.kernel_size = kernel_size | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| def construct(self, x): | |||
| """TNT""" | |||
| N, C, H, W = x.shape | |||
| numH = int(H / self.kernel_size) | |||
| numW = int(W / self.kernel_size) | |||
| if numH * self.kernel_size != H or numW * self.kernel_size != W: | |||
| x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size] | |||
| output_img = self.reshape(x, (N, C, numH, self.kernel_size, W)) | |||
| output_img = self.transpose(output_img, (0, 1, 2, 4, 3)) | |||
| output_img = self.reshape(output_img, (N, C, int( | |||
| numH * numW), self.kernel_size, self.kernel_size)) | |||
| output_img = self.transpose(output_img, (0, 2, 1, 4, 3)) | |||
| output_img = self.reshape(output_img, (N, int(numH * numW), -1)) | |||
| return output_img | |||
| class PixelEmbed(nn.Cell): | |||
| """Image to Pixel Embedding""" | |||
| def __init__(self, img_size, patch_size=16, in_channels=3, embedding_dim=768, stride=4): | |||
| super(PixelEmbed, self).__init__() | |||
| self.num_patches = (img_size // patch_size) * (img_size // patch_size) | |||
| new_patch_size = math.ceil(patch_size / stride) | |||
| self.new_patch_size = new_patch_size | |||
| self.inner_dim = embedding_dim // new_patch_size // new_patch_size | |||
| self.proj = nn.Conv2d(in_channels, self.inner_dim, kernel_size=7, pad_mode='pad', | |||
| padding=3, stride=stride, has_bias=True) | |||
| self.unfold = _unfold_(kernel_size=new_patch_size) | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| def construct(self, x): | |||
| B = x.shape[0] | |||
| x = self.proj(x) # B, C, H, W | |||
| x = self.unfold(x) # B, N, Ck2 | |||
| x = self.reshape(x, (B * self.num_patches, self.inner_dim, -1)) # B*N, C, M | |||
| x = self.transpose(x, (0, 2, 1)) # B*N, M, C | |||
| return x | |||
| class TNT(nn.Cell): | |||
| """TNT""" | |||
| def __init__( | |||
| self, | |||
| img_size, | |||
| patch_size, | |||
| num_channels, | |||
| embedding_dim, | |||
| num_heads, | |||
| num_layers, | |||
| hidden_dim, | |||
| num_class, | |||
| stride=4, | |||
| dropout=0, | |||
| attn_dropout=0, | |||
| drop_connect=0.1 | |||
| ): | |||
| super(TNT, self).__init__() | |||
| assert embedding_dim % num_heads == 0 | |||
| assert img_size % patch_size == 0 | |||
| self.embedding_dim = embedding_dim | |||
| self.num_heads = num_heads | |||
| self.patch_size = patch_size | |||
| self.num_channels = num_channels | |||
| self.img_size = img_size | |||
| self.num_patches = int((img_size // patch_size) ** 2) | |||
| new_patch_size = math.ceil(patch_size / stride) | |||
| inner_dim = embedding_dim // new_patch_size // new_patch_size | |||
| self.patch_pos = Parameter(Tensor(np.random.rand(1, self.num_patches + 1, embedding_dim), | |||
| mstype.float32), name='patch_pos', requires_grad=True) | |||
| self.pixel_pos = Parameter(Tensor(np.random.rand(1, inner_dim, new_patch_size * new_patch_size), | |||
| mstype.float32), name='pixel_pos', requires_grad=True) | |||
| self.cls_token = Parameter(Tensor(np.random.rand(1, 1, embedding_dim), | |||
| mstype.float32), requires_grad=True) | |||
| self.patch_embed = Parameter(Tensor(np.zeros((1, self.num_patches, embedding_dim)), | |||
| mstype.float32), name='patch_embed', requires_grad=False) | |||
| self.fake = Parameter(Tensor(np.zeros((1, 1, embedding_dim)), | |||
| mstype.float32), name='fake', requires_grad=False) | |||
| self.pos_drop = nn.Dropout(1. - dropout) | |||
| self.pixel_embed = PixelEmbed(img_size, patch_size, num_channels, embedding_dim, stride) | |||
| self.pixel2patch = Pixel2Patch(embedding_dim) | |||
| inner_config = {'dim': inner_dim, 'num_heads': 4, 'mlp_ratio': 4} | |||
| outer_config = {'dim': embedding_dim, 'num_heads': num_heads, 'mlp_ratio': hidden_dim / embedding_dim} | |||
| encoder_layer = TNTBlock(inner_config, outer_config, dropout=dropout, attn_dropout=attn_dropout, | |||
| drop_connect=drop_connect) | |||
| self.encoder = TNTEncoder(encoder_layer, num_layers) | |||
| self.head = nn.SequentialCell( | |||
| nn.LayerNorm([embedding_dim]), | |||
| nn.Dense(embedding_dim, num_class) | |||
| ) | |||
| self.add = P.TensorAdd() | |||
| self.reshape = P.Reshape() | |||
| self.concat = P.Concat(axis=1) | |||
| self.tile = P.Tile() | |||
| self.transpose = P.Transpose() | |||
| def construct(self, x): | |||
| """TNT""" | |||
| B, _, _, _ = x.shape | |||
| pixel_embed = self.pixel_embed(x) | |||
| pixel_embed = pixel_embed + self.transpose(self.pixel_pos, (0, 2, 1)) # B*N, M, C | |||
| patch_embed = self.concat((self.cls_token, self.patch_embed)) | |||
| patch_embed = self.tile(patch_embed, (B, 1, 1)) | |||
| patch_embed = self.pos_drop(patch_embed + self.patch_pos) | |||
| patch_embed = self.pixel2patch(pixel_embed, patch_embed) | |||
| pixel_embed, patch_embed = self.encoder(pixel_embed, patch_embed) | |||
| y = self.head(patch_embed[:, 0]) | |||
| return y | |||
| def tnt_b(num_class): | |||
| return TNT(img_size=384, | |||
| patch_size=16, | |||
| num_channels=3, | |||
| embedding_dim=640, | |||
| num_heads=10, | |||
| num_layers=12, | |||
| hidden_dim=640*4, | |||
| stride=4, | |||
| num_class=num_class) | |||