You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Xception.py 6.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Xception."""
  16. import mindspore.nn as nn
  17. import mindspore.ops.operations as P
  class SeparableConv2d(nn.Cell):
      """Depthwise-separable 2D convolution.

      A grouped convolution with one group per input channel (``group=in_channels``)
      is followed by a 1x1 pointwise convolution that maps ``in_channels`` to
      ``out_channels``. Both use Xavier-uniform weight initialization.

      Args:
          in_channels (int): Number of input channels.
          out_channels (int): Number of output channels.
          kernel_size (int): Kernel size of the depthwise convolution. Default: 1.
          stride (int): Stride of the depthwise convolution. Default: 1.
          padding (int): Explicit padding of the depthwise convolution
              (``pad_mode='pad'``). Default: 0.
      """

      def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
          super().__init__()
          # NOTE(review): attribute names `conv1` / `pointwise` are presumably baked into
          # saved checkpoints' parameter names -- keep them unchanged.
          self.conv1 = nn.Conv2d(
              in_channels,
              in_channels,
              kernel_size=kernel_size,
              stride=stride,
              group=in_channels,
              pad_mode='pad',
              padding=padding,
              weight_init='xavier_uniform',
          )
          self.pointwise = nn.Conv2d(
              in_channels,
              out_channels,
              kernel_size=1,
              stride=1,
              pad_mode='valid',
              weight_init='xavier_uniform',
          )

      def construct(self, x):
          # Depthwise filtering first, then channel mixing.
          return self.pointwise(self.conv1(x))
  class Block(nn.Cell):
      """Xception residual block.

      The residual branch is a stack of ``ReLU -> SeparableConv2d(3x3, pad 1) ->
      BatchNorm2d`` units, followed by a 3x3 max pool with stride ``strides`` when
      ``strides != 1``. Its output is added to a shortcut. The shortcut is a strided
      1x1 convolution + BatchNorm2d when the channel count or stride changes, and
      the identity otherwise.

      Args:
          in_filters (int): Number of input channels.
          out_filters (int): Number of output channels.
          reps (int): Number of separable-conv units in the residual branch
              (values below 1 behave like 1).
          strides (int): Down-sampling stride used by the final max pool and by the
              shortcut convolution. Default: 1.
          start_with_relu (bool): If False, the leading ReLU of the residual branch
              is dropped. Default: True.
          grow_first (bool): If True, the first unit maps ``in_filters`` to
              ``out_filters``; otherwise the last unit does. Default: True.
      """

      def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
          super(Block, self).__init__()
          if out_filters != in_filters or strides != 1:
              # Projection shortcut so the shortcut output matches the branch's
              # channel count and spatial size.
              self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, pad_mode='valid', has_bias=False,
                                    weight_init='xavier_uniform')
              self.skipbn = nn.BatchNorm2d(out_filters, momentum=0.9)
          else:
              self.skip = None

          # The order of cells in `rep` fixes their index inside the SequentialCell
          # (and presumably the saved parameter names), so it must stay as-is.
          rep = []
          filters = in_filters
          if grow_first:
              rep.extend(self._sep_unit(in_filters, out_filters))
              filters = out_filters
          for _ in range(reps - 1):
              rep.extend(self._sep_unit(filters, filters))
          if not grow_first:
              rep.extend(self._sep_unit(in_filters, out_filters))
          if not start_with_relu:
              # Every unit starts with a ReLU, so rep[0] is always a ReLU here.
              rep = rep[1:]
          if strides != 1:
              rep.append(nn.MaxPool2d(3, strides, pad_mode="same"))
          self.rep = nn.SequentialCell(*rep)
          self.add = P.Add()

      @staticmethod
      def _sep_unit(in_channels, out_channels):
          """Return one [ReLU, SeparableConv2d(3x3, pad 1), BatchNorm2d] unit as a list."""
          return [
              nn.ReLU(),
              SeparableConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
              nn.BatchNorm2d(out_channels, momentum=0.9),
          ]

      def construct(self, inp):
          x = self.rep(inp)
          if self.skip is not None:
              skip = self.skip(inp)
              skip = self.skipbn(skip)
          else:
              skip = inp
          x = self.add(x, skip)
          return x
  class Xception(nn.Cell):
      """
      Xception optimized for the ImageNet dataset, as specified in
      https://arxiv.org/abs/1610.02357

      Layout: a two-convolution stem, an entry flow (block1-block3), a middle flow
      (block4-block11), an exit flow (block12, conv3, conv4), then global average
      pooling, dropout and a Dense classifier.

      The classifier head averages over the whole spatial extent. The network is
      therefore not tied to 299x299 inputs, which yield a 10x10 final feature map.
      """

      def __init__(self, num_classes=1000):
          """Constructor.

          Args:
              num_classes (int): Number of outputs of the final Dense layer. Default: 1000.
          """
          super(Xception, self).__init__()
          self.num_classes = num_classes
          # Stem
          self.conv1 = nn.Conv2d(3, 32, 3, 2, pad_mode='valid', weight_init='xavier_uniform')
          self.bn1 = nn.BatchNorm2d(32, momentum=0.9)
          self.relu = nn.ReLU()
          self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='valid', weight_init='xavier_uniform')
          self.bn2 = nn.BatchNorm2d(64, momentum=0.9)
          # Entry flow
          self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
          self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
          self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
          # Middle flow
          self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
          self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
          self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
          self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
          self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
          self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
          self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
          self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
          # Exit flow
          self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
          self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
          self.bn3 = nn.BatchNorm2d(1536, momentum=0.9)
          self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
          self.bn4 = nn.BatchNorm2d(2048, momentum=0.9)
          # Global average pooling: mean over the spatial axes (2, 3), giving (N, 2048).
          # Assumes the default NCHW layout of nn.Conv2d. A fixed AvgPool2d(10) only
          # collapses the 10x10 map of a 299x299 input; any other size broke `fc`.
          self.global_avg_pool = P.ReduceMean(keep_dims=False)
          self.dropout = nn.Dropout()
          self.fc = nn.Dense(2048, num_classes)

      def construct(self, x):
          # Stem
          x = self.conv1(x)
          x = self.bn1(x)
          x = self.relu(x)
          x = self.conv2(x)
          x = self.bn2(x)
          x = self.relu(x)
          # Entry flow
          x = self.block1(x)
          x = self.block2(x)
          x = self.block3(x)
          # Middle flow
          x = self.block4(x)
          x = self.block5(x)
          x = self.block6(x)
          x = self.block7(x)
          x = self.block8(x)
          x = self.block9(x)
          x = self.block10(x)
          x = self.block11(x)
          # Exit flow
          x = self.block12(x)
          x = self.conv3(x)
          x = self.bn3(x)
          x = self.relu(x)
          x = self.conv4(x)
          x = self.bn4(x)
          x = self.relu(x)
          # Head: (N, 2048, H, W) -> (N, 2048) -> (N, num_classes)
          x = self.global_avg_pool(x, (2, 3))
          x = self.dropout(x)
          x = self.fc(x)
          return x
  def xception(class_num=1000):
      """
      Build an Xception network.

      Args:
          class_num (int): Number of output classes of the classifier. Default: 1000.

      Returns:
          Cell, an ``Xception`` instance constructed with ``num_classes=class_num``.

      Examples:
          >>> net = xception(1000)
      """
      network = Xception(num_classes=class_num)
      return network