You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

resnet_model.py 7.5 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # !/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. """
  4. /**
  5. * Copyright 2020 Tianshu AI Platform. All Rights Reserved.
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. * =============================================================
  19. */
  20. Reference:
  21. - [Identity Mappings in Deep Residual Networks]
  22. (https://arxiv.org/abs/1603.05027) (CVPR 2016)
  23. """
  24. from __future__ import absolute_import
  25. from __future__ import division
  26. from __future__ import print_function
  27. import oneflow as flow
  28. BLOCK_COUNTS = [3, 4, 6, 3]
  29. BLOCK_FILTERS = [256, 512, 1024, 2048]
  30. BLOCK_FILTERS_INNER = [64, 128, 256, 512]
  31. class ResnetBuilder(object):
  32. def __init__(self, weight_regularizer, trainable=True, training=True):
  33. self.weight_initializer = flow.variance_scaling_initializer(
  34. 2, 'fan_in', 'random_normal', data_format="NCHW")
  35. self.weight_regularizer = weight_regularizer
  36. self.trainable = trainable
  37. self.training = training
  38. def _conv2d(
  39. self,
  40. name,
  41. input,
  42. filters,
  43. kernel_size,
  44. strides=1,
  45. padding="SAME",
  46. data_format="NCHW",
  47. dilations=1,
  48. ):
  49. weight = flow.get_variable(
  50. name + "-weight",
  51. shape=(filters, input.shape[1], kernel_size, kernel_size),
  52. dtype=input.dtype,
  53. initializer=self.weight_initializer,
  54. regularizer=self.weight_regularizer,
  55. model_name="weight",
  56. trainable=self.trainable,
  57. )
  58. return flow.nn.conv2d(
  59. input,
  60. weight,
  61. strides,
  62. padding,
  63. data_format,
  64. dilations,
  65. name=name)
  66. def _batch_norm(self, inputs, name=None, last=False):
  67. initializer = flow.zeros_initializer() if last else flow.ones_initializer()
  68. return flow.layers.batch_normalization(
  69. inputs=inputs,
  70. axis=1,
  71. momentum=0.9, # 97,
  72. epsilon=1e-5,
  73. center=True,
  74. scale=True,
  75. trainable=self.trainable,
  76. training=self.training,
  77. gamma_initializer=initializer,
  78. moving_variance_initializer=initializer,
  79. gamma_regularizer=self.weight_regularizer,
  80. beta_regularizer=self.weight_regularizer,
  81. name=name,
  82. )
  83. def conv2d_affine(
  84. self,
  85. input,
  86. name,
  87. filters,
  88. kernel_size,
  89. strides,
  90. activation=None,
  91. last=False):
  92. # input data_format must be NCHW, cannot check now
  93. padding = "SAME" if strides > 1 or kernel_size > 1 else "VALID"
  94. output = self._conv2d(
  95. name,
  96. input,
  97. filters,
  98. kernel_size,
  99. strides,
  100. padding)
  101. output = self._batch_norm(output, name + "_bn", last=last)
  102. if activation == "Relu":
  103. output = flow.nn.relu(output)
  104. return output
  105. def bottleneck_transformation(
  106. self,
  107. input,
  108. block_name,
  109. filters,
  110. filters_inner,
  111. strides):
  112. a = self.conv2d_affine(
  113. input,
  114. block_name +
  115. "_branch2a",
  116. filters_inner,
  117. 1,
  118. 1,
  119. activation="Relu")
  120. b = self.conv2d_affine(
  121. a,
  122. block_name +
  123. "_branch2b",
  124. filters_inner,
  125. 3,
  126. strides,
  127. activation="Relu")
  128. c = self.conv2d_affine(
  129. b,
  130. block_name +
  131. "_branch2c",
  132. filters,
  133. 1,
  134. 1,
  135. last=True)
  136. return c
  137. def residual_block(
  138. self,
  139. input,
  140. block_name,
  141. filters,
  142. filters_inner,
  143. strides_init):
  144. if strides_init != 1 or block_name == "res2_0":
  145. shortcut = self.conv2d_affine(
  146. input, block_name + "_branch1", filters, 1, strides_init
  147. )
  148. else:
  149. shortcut = input
  150. bottleneck = self.bottleneck_transformation(
  151. input, block_name, filters, filters_inner, strides_init,
  152. )
  153. return flow.nn.relu(bottleneck + shortcut)
  154. def residual_stage(
  155. self,
  156. input,
  157. stage_name,
  158. counts,
  159. filters,
  160. filters_inner,
  161. stride_init=2):
  162. output = input
  163. for i in range(counts):
  164. block_name = "%s_%d" % (stage_name, i)
  165. output = self.residual_block(
  166. output,
  167. block_name,
  168. filters,
  169. filters_inner,
  170. stride_init if i == 0 else 1)
  171. return output
  172. def resnet_conv_x_body(self, input):
  173. output = input
  174. for i, (counts, filters, filters_inner) in enumerate(
  175. zip(BLOCK_COUNTS, BLOCK_FILTERS, BLOCK_FILTERS_INNER)
  176. ):
  177. stage_name = "res%d" % (i + 2)
  178. output = self.residual_stage(
  179. output,
  180. stage_name,
  181. counts,
  182. filters,
  183. filters_inner,
  184. 1 if i == 0 else 2)
  185. return output
  186. def resnet_stem(self, input):
  187. conv1 = self._conv2d("conv1", input, 64, 7, 2)
  188. conv1_bn = flow.nn.relu(self._batch_norm(conv1, "conv1_bn"))
  189. pool1 = flow.nn.max_pool2d(
  190. conv1_bn,
  191. ksize=3,
  192. strides=2,
  193. padding="SAME",
  194. data_format="NCHW",
  195. name="pool1",
  196. )
  197. return pool1
  198. def resnet50(
  199. images,
  200. trainable=True,
  201. need_transpose=False,
  202. training=True,
  203. wd=1.0 / 32768):
  204. weight_regularizer = flow.regularizers.l2(
  205. wd) if wd > 0.0 and wd < 1.0 else None
  206. builder = ResnetBuilder(weight_regularizer, trainable, training)
  207. # note: images.shape = (N C H W) in cc's new dataloader, transpose is not
  208. # needed anymore
  209. if need_transpose:
  210. images = flow.transpose(images, name="transpose", perm=[0, 3, 1, 2])
  211. with flow.deprecated.variable_scope("Resnet"):
  212. stem = builder.resnet_stem(images)
  213. body = builder.resnet_conv_x_body(stem)
  214. pool5 = flow.nn.avg_pool2d(
  215. body,
  216. ksize=7,
  217. strides=1,
  218. padding="VALID",
  219. data_format="NCHW",
  220. name="pool5",
  221. )
  222. fc1001 = flow.layers.dense(
  223. flow.reshape(pool5, (pool5.shape[0], -1)),
  224. units=1000,
  225. use_bias=True,
  226. kernel_initializer=flow.variance_scaling_initializer(2, 'fan_in', 'random_normal'),
  227. bias_initializer=flow.zeros_initializer(),
  228. kernel_regularizer=weight_regularizer,
  229. bias_regularizer=weight_regularizer,
  230. trainable=trainable,
  231. name="fc1001",
  232. )
  233. return fc1001

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能