Merge pull request !6642 from wangmin0104/mastertags/v1.0.0
| @@ -0,0 +1,21 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hub config.""" | |||||
| from src.resnet_thor import resnet50 | |||||
| def create_network(name, *args, **kwargs): | |||||
| if name == 'resnet50_thor': | |||||
| return resnet50(*args, **kwargs) | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||
| @@ -273,7 +273,8 @@ class ResNet(nn.Cell): | |||||
| damping, | damping, | ||||
| loss_scale, | loss_scale, | ||||
| frequency, | frequency, | ||||
| batch_size): | |||||
| batch_size, | |||||
| include_top=True): | |||||
| super(ResNet, self).__init__() | super(ResNet, self).__init__() | ||||
| if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: | if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: | ||||
| @@ -321,11 +322,12 @@ class ResNet(nn.Cell): | |||||
| loss_scale=loss_scale, | loss_scale=loss_scale, | ||||
| frequency=frequency, | frequency=frequency, | ||||
| batch_size=batch_size) | batch_size=batch_size) | ||||
| self.mean = P.ReduceMean(keep_dims=True) | |||||
| self.flatten = nn.Flatten() | |||||
| self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, | |||||
| frequency=frequency, batch_size=batch_size) | |||||
| self.include_top = include_top | |||||
| if self.include_top: | |||||
| self.mean = P.ReduceMean(keep_dims=True) | |||||
| self.flatten = nn.Flatten() | |||||
| self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, | |||||
| frequency=frequency, batch_size=batch_size) | |||||
| def _make_layer(self, block, layer_num, in_channel, out_channel, stride, | def _make_layer(self, block, layer_num, in_channel, out_channel, stride, | ||||
| damping, loss_scale, frequency, batch_size): | damping, loss_scale, frequency, batch_size): | ||||
| @@ -371,6 +373,9 @@ class ResNet(nn.Cell): | |||||
| c4 = self.layer3(c3) | c4 = self.layer3(c3) | ||||
| c5 = self.layer4(c4) | c5 = self.layer4(c4) | ||||
| if not self.include_top: | |||||
| return x | |||||
| out = self.mean(c5, (2, 3)) | out = self.mean(c5, (2, 3)) | ||||
| out = self.flatten(out) | out = self.flatten(out) | ||||
| out = self.end_point(out) | out = self.end_point(out) | ||||
| @@ -378,7 +383,7 @@ class ResNet(nn.Cell): | |||||
| return out | return out | ||||
| def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32): | |||||
| def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32, include_top=True): | |||||
| """ | """ | ||||
| Get ResNet50 neural network. | Get ResNet50 neural network. | ||||
| @@ -400,4 +405,5 @@ def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size | |||||
| damping, | damping, | ||||
| loss_scale, | loss_scale, | ||||
| frequency, | frequency, | ||||
| batch_size) | |||||
| batch_size, | |||||
| include_top) | |||||
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| ''' | |||||
| Bert hub interface for bert_thor | |||||
| ''' | |||||
| from src.bert_model import BertModel | |||||
| from src.bert_model import BertConfig | |||||
| import mindspore.common.dtype as mstype | |||||
| bert_net_cfg = BertConfig( | |||||
| batch_size=12, | |||||
| seq_length=512, | |||||
| vocab_size=30522, | |||||
| hidden_size=1024, | |||||
| num_hidden_layers=24, | |||||
| num_attention_heads=16, | |||||
| intermediate_size=4096, | |||||
| hidden_act="gelu", | |||||
| hidden_dropout_prob=0.1, | |||||
| attention_probs_dropout_prob=0.1, | |||||
| max_position_embeddings=512, | |||||
| type_vocab_size=2, | |||||
| initializer_range=0.02, | |||||
| use_relative_positions=False, | |||||
| input_mask_from_dataset=True, | |||||
| token_type_ids_from_dataset=True, | |||||
| dtype=mstype.float32, | |||||
| compute_type=mstype.float16, | |||||
| enable_fused_layernorm=True | |||||
| ) | |||||
| def create_network(name, *args, **kwargs): | |||||
| ''' | |||||
| Create bert network for bert_thor. | |||||
| ''' | |||||
| if name == 'bert_thor': | |||||
| is_training = kwargs.get("is_training", default=False) | |||||
| return BertModel(bert_net_cfg, is_training, *args) | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||