Merge pull request !6642 from wangmin0104/mastertags/v1.0.0
| @@ -0,0 +1,21 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """hub config.""" | |||
| from src.resnet_thor import resnet50 | |||
| def create_network(name, *args, **kwargs): | |||
| if name == 'resnet50_thor': | |||
| return resnet50(*args, **kwargs) | |||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||
| @@ -273,7 +273,8 @@ class ResNet(nn.Cell): | |||
| damping, | |||
| loss_scale, | |||
| frequency, | |||
| batch_size): | |||
| batch_size, | |||
| include_top=True): | |||
| super(ResNet, self).__init__() | |||
| if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: | |||
| @@ -321,11 +322,12 @@ class ResNet(nn.Cell): | |||
| loss_scale=loss_scale, | |||
| frequency=frequency, | |||
| batch_size=batch_size) | |||
| self.mean = P.ReduceMean(keep_dims=True) | |||
| self.flatten = nn.Flatten() | |||
| self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, | |||
| frequency=frequency, batch_size=batch_size) | |||
| self.include_top = include_top | |||
| if self.include_top: | |||
| self.mean = P.ReduceMean(keep_dims=True) | |||
| self.flatten = nn.Flatten() | |||
| self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, | |||
| frequency=frequency, batch_size=batch_size) | |||
| def _make_layer(self, block, layer_num, in_channel, out_channel, stride, | |||
| damping, loss_scale, frequency, batch_size): | |||
| @@ -371,6 +373,9 @@ class ResNet(nn.Cell): | |||
| c4 = self.layer3(c3) | |||
| c5 = self.layer4(c4) | |||
| if not self.include_top: | |||
| return x | |||
| out = self.mean(c5, (2, 3)) | |||
| out = self.flatten(out) | |||
| out = self.end_point(out) | |||
| @@ -378,7 +383,7 @@ class ResNet(nn.Cell): | |||
| return out | |||
| def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32): | |||
| def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32, include_top=True): | |||
| """ | |||
| Get ResNet50 neural network. | |||
| @@ -400,4 +405,5 @@ def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size | |||
| damping, | |||
| loss_scale, | |||
| frequency, | |||
| batch_size) | |||
| batch_size, | |||
| include_top) | |||
| @@ -0,0 +1,49 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| ''' | |||
| Bert hub interface for bert_thor | |||
| ''' | |||
| from src.bert_model import BertModel | |||
| from src.bert_model import BertConfig | |||
| import mindspore.common.dtype as mstype | |||
| bert_net_cfg = BertConfig( | |||
| batch_size=12, | |||
| seq_length=512, | |||
| vocab_size=30522, | |||
| hidden_size=1024, | |||
| num_hidden_layers=24, | |||
| num_attention_heads=16, | |||
| intermediate_size=4096, | |||
| hidden_act="gelu", | |||
| hidden_dropout_prob=0.1, | |||
| attention_probs_dropout_prob=0.1, | |||
| max_position_embeddings=512, | |||
| type_vocab_size=2, | |||
| initializer_range=0.02, | |||
| use_relative_positions=False, | |||
| input_mask_from_dataset=True, | |||
| token_type_ids_from_dataset=True, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float16, | |||
| enable_fused_layernorm=True | |||
| ) | |||
| def create_network(name, *args, **kwargs): | |||
| ''' | |||
| Create bert network for bert_thor. | |||
| ''' | |||
| if name == 'bert_thor': | |||
| is_training = kwargs.get("is_training", default=False) | |||
| return BertModel(bert_net_cfg, is_training, *args) | |||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||