You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. from blocks import ShuffleV2Block
  17. from mindspore import Tensor
  18. import mindspore.nn as nn
  19. import mindspore.ops.operations as P
  20. class ShuffleNetV2(nn.Cell):
  21. def __init__(self, input_size=224, n_class=1000, model_size='1.0x'):
  22. super(ShuffleNetV2, self).__init__()
  23. print('model size is ', model_size)
  24. self.stage_repeats = [4, 8, 4]
  25. self.model_size = model_size
  26. if model_size == '0.5x':
  27. self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
  28. elif model_size == '1.0x':
  29. self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
  30. elif model_size == '1.5x':
  31. self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
  32. elif model_size == '2.0x':
  33. self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
  34. else:
  35. raise NotImplementedError
  36. # building first layer
  37. input_channel = self.stage_out_channels[1]
  38. self.first_conv = nn.SequentialCell([
  39. nn.Conv2d(in_channels=3, out_channels=input_channel, kernel_size=3, stride=2,
  40. pad_mode='pad', padding=1, has_bias=False),
  41. nn.BatchNorm2d(num_features=input_channel, momentum=0.9),
  42. nn.ReLU(),
  43. ])
  44. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
  45. self.features = []
  46. for idxstage in range(len(self.stage_repeats)):
  47. numrepeat = self.stage_repeats[idxstage]
  48. output_channel = self.stage_out_channels[idxstage+2]
  49. for i in range(numrepeat):
  50. if i == 0:
  51. self.features.append(ShuffleV2Block(input_channel, output_channel,
  52. mid_channels=output_channel // 2, ksize=3, stride=2))
  53. else:
  54. self.features.append(ShuffleV2Block(input_channel // 2, output_channel,
  55. mid_channels=output_channel // 2, ksize=3, stride=1))
  56. input_channel = output_channel
  57. self.features = nn.SequentialCell([*self.features])
  58. self.conv_last = nn.SequentialCell([
  59. nn.Conv2d(in_channels=input_channel, out_channels=self.stage_out_channels[-1], kernel_size=1, stride=1,
  60. pad_mode='pad', padding=0, has_bias=False),
  61. nn.BatchNorm2d(num_features=self.stage_out_channels[-1], momentum=0.9),
  62. nn.ReLU()
  63. ])
  64. self.globalpool = nn.AvgPool2d(kernel_size=7, stride=7, pad_mode='valid')
  65. if self.model_size == '2.0x':
  66. self.dropout = nn.Dropout(keep_prob=0.8)
  67. self.classifier = nn.SequentialCell([nn.Dense(in_channels=self.stage_out_channels[-1],
  68. out_channels=n_class, has_bias=False)])
  69. ##TODO init weights
  70. self._initialize_weights()
  71. def construct(self, x):
  72. x = self.first_conv(x)
  73. x = self.maxpool(x)
  74. x = self.features(x)
  75. x = self.conv_last(x)
  76. x = self.globalpool(x)
  77. if self.model_size == '2.0x':
  78. x = self.dropout(x)
  79. x = P.Reshape()(x, (-1, self.stage_out_channels[-1],))
  80. x = self.classifier(x)
  81. return x
  82. def _initialize_weights(self):
  83. for name, m in self.cells_and_names():
  84. if isinstance(m, nn.Conv2d):
  85. if 'first' in name:
  86. m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01,
  87. m.weight.data.shape).astype("float32")))
  88. else:
  89. m.weight.set_parameter_data(Tensor(np.random.normal(0, 1.0/m.weight.data.shape[1],
  90. m.weight.data.shape).astype("float32")))
  91. if isinstance(m, nn.Dense):
  92. m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))