You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

mindspore_nn.py 67 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. from __future__ import absolute_import, division, print_function
  4. import itertools
  5. import mindspore as ms
  6. import mindspore.ops as P
  7. from mindspore import context
  8. from mindspore.nn.cell import Cell
  9. from mindspore._checkparam import Rel
  10. from mindspore.ops import functional as F
  11. from mindspore.communication import management
  12. from mindspore.ops.operations import _inner_ops as inner
  13. from mindspore._extends import cell_attr_register
  14. from mindspore.ops._grad.grad_base import bprop_getters
  15. from mindspore._checkparam import Validator as validator
  16. from mindspore.communication.management import get_group_size, get_rank
  17. def padding_format(padding):
  18. """
  19. Checks that the padding format correspond format.
  20. Parameters
  21. ----------
  22. padding : str
  23. Must be one of the following:"same", "SAME", "VALID", "valid"
  24. Returns
  25. -------
  26. str "SAME" or "VALID"
  27. """
  28. if padding in ["SAME", "same"]:
  29. padding = "same"
  30. elif padding in ["VALID", "valid"]:
  31. padding = "valid"
  32. elif padding == None:
  33. padding = None
  34. else:
  35. raise Exception("Unsupported padding: " + str(padding))
  36. return padding
  37. def preprocess_1d_format(data_format, padding):
  38. """
  39. Checks that the 1-D dataformat format correspond format.
  40. Parameters
  41. ----------
  42. data_format : str
  43. Must be one of the following:"channels_last","NWC","NCW","channels_first"
  44. padding : str
  45. Must be one of the following:"same","valid","SAME","VALID"
  46. Returns
  47. -------
  48. str "NWC" or "NCW" and "SAME" or "VALID"
  49. """
  50. if data_format in ["channels_last", "NWC"]:
  51. data_format = "NWC"
  52. elif data_format in ["channels_first", "NCW"]:
  53. data_format = "NCW"
  54. elif data_format == None:
  55. data_format = None
  56. else:
  57. raise Exception("Unsupported data format: " + str(data_format))
  58. padding = padding_format(padding)
  59. return data_format, padding
  60. def preprocess_2d_format(data_format, padding):
  61. """
  62. Checks that the 2-D dataformat format correspond format.
  63. Parameters
  64. ----------
  65. data_format : str
  66. Must be one of the following:"channels_last","NHWC","NCHW","channels_first"
  67. padding : str
  68. Must be one of the following:"same","valid","SAME","VALID"
  69. Returns
  70. -------
  71. str "NHWC" or "NCHW" and "SAME" or "VALID"
  72. """
  73. if data_format in ["channels_last", "NHWC", "nhwc"]:
  74. data_format = "NHWC"
  75. elif data_format in ["channels_first", "NCHW", "nchw"]:
  76. data_format = "NCHW"
  77. elif data_format == None:
  78. data_format = None
  79. else:
  80. raise Exception("Unsupported data format: " + str(data_format))
  81. padding = padding_format(padding)
  82. return data_format, padding
  83. def preprocess_3d_format(data_format, padding):
  84. """
  85. Checks that the 3-D dataformat format correspond format.
  86. Parameters
  87. ----------
  88. data_format : str
  89. Must be one of the following:"channels_last","NDHWC","NCDHW","channels_first"
  90. padding : str
  91. Must be one of the following:"same","valid","SAME","VALID"
  92. Returns
  93. -------
  94. str "NDHWC" or "NCDHW" and "SAME" or "VALID"
  95. """
  96. if data_format in ['channels_last', 'NDHWC']:
  97. data_format = 'NDHWC'
  98. elif data_format in ['channels_first', 'NCDHW']:
  99. data_format = 'NCDHW'
  100. elif data_format == None:
  101. data_format = None
  102. else:
  103. raise Exception("Unsupported data format: " + str(data_format))
  104. padding = padding_format(padding)
  105. return data_format, padding
  106. def nchw_to_nhwc(x):
  107. """
  108. Channels first to channels last
  109. Parameters
  110. ----------
  111. x : tensor
  112. channels first tensor data
  113. Returns
  114. -------
  115. channels last tensor data
  116. """
  117. if len(P.Shape()(x)) == 3:
  118. x = P.Transpose()(x, (0, 2, 1))
  119. elif len(P.Shape()(x)) == 4:
  120. x = P.Transpose()(x, (0, 2, 3, 1))
  121. elif len(P.Shape()(x)) == 5:
  122. x = P.Transpose()(x, (0, 2, 3, 4, 1))
  123. # else:
  124. # raise Exception("Unsupported dimensions")
  125. return x
  126. def nhwc_to_nchw(x):
  127. """
  128. Channles last to channels first
  129. Parameters
  130. ----------
  131. x : tensor
  132. channels last tensor data
  133. Returns
  134. -------
  135. channels first tensor data
  136. """
  137. if len(P.Shape()(x)) == 3:
  138. x = P.Transpose()(x, (0, 2, 1))
  139. elif len(P.Shape()(x)) == 4:
  140. x = P.Transpose()(x, (0, 3, 1, 2))
  141. elif len(P.Shape()(x)) == 5:
  142. x = P.Transpose()(x, (0, 4, 1, 2, 3))
  143. # else:
  144. # raise Exception("Unsupported dimensions")
  145. return x
  146. class ReLU(Cell):
  147. def __init__(self):
  148. super(ReLU, self).__init__()
  149. self.relu = P.ReLU()
  150. def construct(self, x):
  151. return self.relu(x)
  152. def relu(x):
  153. """
  154. Computes rectified linear: max(features, 0).
  155. Parameters
  156. ----------
  157. x : tensor
  158. Must be one of the following types: float32, float64, int32, uint8, int16,
  159. int8, int64, bfloat16, uint16, half, uint32, uint64, qint8.
  160. Returns
  161. -------
  162. A Tensor. Has the same type as features.
  163. """
  164. outputs = P.ReLU()
  165. return outputs(x)
  166. class ReLU6(Cell):
  167. def __init__(self):
  168. super(ReLU6, self).__init__()
  169. self.relu6 = P.ReLU6()
  170. def construct(self, x):
  171. return self.relu6(x)
  172. def relu6(x):
  173. """
  174. Computes Rectified Linear 6: min(max(features, 0), 6).
  175. Parameters
  176. ----------
  177. x : tensor
  178. Must be one of the following types: float32, float64, int32, uint8, int16,
  179. int8, int64, bfloat16, uint16, half, uint32, uint64, qint8.
  180. Returns
  181. -------
  182. A Tensor with the same type as features.
  183. """
  184. outputs = P.ReLU6()
  185. return outputs(x)
  186. class LeakyReLU(Cell):
  187. def __init__(self, alpha=0.2):
  188. super(LeakyReLU, self).__init__()
  189. self.leakyrelu = ms.nn.LeakyReLU(alpha=alpha)
  190. def construct(self, x):
  191. return self.leakyrelu(x)
  192. def leaky_relu(x, alpha=0.2):
  193. """
  194. Compute the Leaky ReLU activation function.
  195. Parameters
  196. ----------
  197. x : tensor
  198. representing preactivation values. Must be one of the following types:
  199. float16, float32, float64, int32, int64.
  200. Returns
  201. -------
  202. The activation value.
  203. """
  204. leaky_relu = LeakyReLU(alpha=alpha)
  205. output = leaky_relu(x)
  206. return leaky_relu
  207. class Softplus(Cell):
  208. def __init__(self):
  209. super(Softplus, self).__init__()
  210. self.softplus = P.Softplus()
  211. def construct(self, x):
  212. return self.softplus(x)
  213. def softplus(x):
  214. """
  215. Computes softplus: log(exp(features) + 1).
  216. Parameters
  217. ----------
  218. x : tensor
  219. Must be one of the following types: half, bfloat16, float32, float64.
  220. Returns
  221. -------
  222. A Tensor. Has the same type as features.
  223. """
  224. obj = Softplus()
  225. return obj(x)
  226. class Tanh(Cell):
  227. def __init__(self):
  228. super(Tanh, self).__init__()
  229. self.tanh = P.Tanh()
  230. def construct(self, x):
  231. return self.tanh(x)
  232. def tanh(x):
  233. """
  234. Computes hyperbolic tangent of x element-wise.
  235. Parameters
  236. ----------
  237. x : tensor
  238. Must be one of the following types: bfloat16, half, float32, float64, complex64, complex128.
  239. Returns
  240. -------
  241. A Tensor. Has the same type as x.
  242. """
  243. _tanh = Tanh()
  244. return _tanh(x)
  245. class Sigmoid(Cell):
  246. def __init__(self):
  247. super(Sigmoid, self).__init__()
  248. self.sigmoid = P.Sigmoid()
  249. def construct(self, x):
  250. return self.sigmoid(x)
  251. def sigmoid(x):
  252. """
  253. Computes sigmoid of x element-wise.
  254. Parameters
  255. ----------
  256. x : tensor
  257. A Tensor with type float16, float32, float64, complex64, or complex128.
  258. Returns
  259. -------
  260. A Tensor with the same type as x.
  261. """
  262. outputs = P.Sigmoid()
  263. return outputs(x)
  264. class Softmax(Cell):
  265. def __init__(self):
  266. super(Softmax, self).__init__()
  267. self.softmax = P.Softmax()
  268. def construct(self, x):
  269. return self.softmax(x)
  270. def softmax(logits, axis=None):
  271. """
  272. Computes softmax activations.
  273. Parameters
  274. ----------
  275. logits : tensor
  276. Must be one of the following types: half, float32, float64.
  277. axis : int
  278. The dimension softmax would be performed on. The default is -1 which indicates the last dimension.
  279. Returns
  280. -------
  281. A Tensor. Has the same type and shape as logits.
  282. """
  283. outputs = P.Softmax(axis)
  284. return outputs(logits)
  285. class Dropout(Cell):
  286. def __init__(self, keep, seed=0):
  287. super(Dropout, self).__init__()
  288. self.dropout = P.Dropout(keep_prob=keep)
  289. self.is_gpu = context.get_context('device_target') in ["GPU"]
  290. self.get_shape = P.Shape()
  291. self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed, Seed1=0)
  292. self.dropout_do_mask = P.DropoutDoMask()
  293. self.cast = P.Cast()
  294. self.keep_prob = keep # ms.Tensor(keep, dtype=ms.float32)
  295. # print(self.keep_prob, type(self.keep_prob))
  296. def construct(self, inputs):
  297. if self.is_gpu:
  298. outputs, _ = self.dropout(inputs)
  299. return outputs
  300. if self.keep_prob == 1:
  301. return inputs
  302. shape = self.get_shape(inputs)
  303. dtype = P.DType()(inputs)
  304. if self._is_float_dtype(dtype):
  305. keep_prob = self.cast(self.keep_prob, dtype=dtype)
  306. else:
  307. keep_prob = self.cast(self.keep_prob, ms.float16)
  308. output = self.dropout_gen_mask(shape, keep_prob)
  309. return self.dropout_do_mask(inputs, output, keep_prob)
  310. def _is_float_dtype(dtype):
  311. if dtype in [ms.float32, ms.float16]:
  312. return True
  313. return False
  314. class BiasAdd(Cell):
  315. """
  316. Adds bias to value.
  317. Parameters
  318. ----------
  319. x : tensor
  320. A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128.
  321. bias : tensor
  322. Must be the same type as value unless value is a quantized type,
  323. in which case a different quantized type may be used.
  324. Returns
  325. -------
  326. A Tensor with the same type as value.
  327. """
  328. def __init__(self, data_format='channels_first'):
  329. super(BiasAdd, self).__init__()
  330. self.bias_add = P.BiasAdd()
  331. if data_format in ['channels_first', 'NCW', 'NCHW', 'NCDHW']:
  332. self.data_format = 'channels_first'
  333. elif data_format in ['channels_last', 'NWC', 'NHWC', 'NDHWC']:
  334. self.data_format = 'channels_last'
  335. else:
  336. raise ("Unsupported data format: " + str(data_format))
  337. def construct(self, x, bias):
  338. if self.data_format == 'channels_last':
  339. x = nhwc_to_nchw(x)
  340. outputs = self.bias_add(x, bias)
  341. if self.data_format == 'channels_last':
  342. outputs = nchw_to_nhwc(outputs)
  343. return outputs
  344. def bias_add(x, bias):
  345. """
  346. Adds bias to value.
  347. Parameters
  348. ----------
  349. x : tensor
  350. A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128.
  351. bias : tensor
  352. Must be the same type as value unless value is a quantized type,
  353. in which case a different quantized type may be used.
  354. data_format : A string.
  355. 'N...C' and 'NC...' are supported.
  356. name : str
  357. A name for the operation (optional).
  358. Returns
  359. -------
  360. A Tensor with the same type as value.
  361. """
  362. raise NotImplementedError
  363. class Conv1D(Cell):
  364. def __init__(self, stride, padding, data_format='NWC', dilations=None, out_channel=None, k_size=None):
  365. super(Conv1D, self).__init__()
  366. self.data_format, self.padding = preprocess_1d_format(data_format, padding)
  367. self.stride = (1, stride)
  368. self.dilations = (1, dilations)
  369. self.k_size = (1, k_size)
  370. self.out_channel = out_channel
  371. self.conv2d = P.Conv2D(
  372. out_channel=self.out_channel, kernel_size=self.k_size, pad_mode=self.padding, stride=self.stride,
  373. dilation=self.dilations, mode=1, group=1
  374. )
  375. self.expand_dims = P.ExpandDims()
  376. self.squeeze = P.Squeeze(2)
  377. def construct(self, x, filters):
  378. if self.data_format == 'NWC':
  379. x = nhwc_to_nchw(x)
  380. x = self.expand_dims(x, 2)
  381. filters = self.expand_dims(filters, 2)
  382. output = self.conv2d(x, filters)
  383. output = self.squeeze(output)
  384. if self.data_format == 'NWC':
  385. output = nchw_to_nhwc(output)
  386. return output
  387. def conv1d(input, filters, stride, padding, data_format='NWC', dilations=None, name=None):
  388. """
  389. Computes a 1-D convolution given 3-D input and filter tensors.
  390. Parameters
  391. ----------
  392. input : tensor
  393. A 3D Tensor. Must be of type float16, float32, or float64
  394. filters : tensor
  395. A 3D Tensor. Must have the same type as input.
  396. stride : int of list
  397. An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step.
  398. padding : string
  399. 'SAME' or 'VALID'
  400. data_format : string
  401. An optional string from "NWC", "NCW". Defaults to "NWC", the data is stored in the order of
  402. [batch, in_width, in_channels]. The "NCW" format stores data as [batch, in_channels, in_width].
  403. dilations : int or list
  404. An int or list of ints that has length 1 or 3 which defaults to 1.
  405. The dilation factor for each dimension of input. If set to k > 1,
  406. there will be k-1 skipped cells between each filter element on that dimension.
  407. Dilations in the batch and depth dimensions must be 1.
  408. name : string
  409. A name for the operation (optional).
  410. Returns
  411. -------
  412. A Tensor. Has the same type as input.
  413. """
  414. pass
  415. class Conv2D(Cell):
  416. def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None):
  417. super(Conv2D, self).__init__()
  418. self.data_format, self.padding = preprocess_2d_format(data_format, padding)
  419. if self.data_format is 'NHWC':
  420. self.ms_stride = strides[1]
  421. self.ms_dilation = dilations[1]
  422. elif self.data_format is 'NCHW':
  423. self.ms_stride = strides[2]
  424. self.ms_dilation = dilations[2]
  425. self.conv2d = P.Conv2D(
  426. out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
  427. dilation=self.ms_dilation, mode=1, group=1, data_format=self.data_format
  428. )
  429. def construct(self, inputs, filters):
  430. outputs = self.conv2d(inputs, filters)
  431. return outputs
  432. def conv2d(input, filters, strides, padding, data_format='NCHW', dilations=None):
  433. """
  434. Computes a 2-D convolution given 4-D input and filters tensors.
  435. Parameters
  436. ----------
  437. input : tensor
  438. Must be one of the following types: half, bfloat16, float32, float64. A 4-D tensor.
  439. The dimension order is interpreted according to the value of data_format, see below for details.
  440. filters : tensor
  441. Must have the same type as input. A 4-D tensor of shape [filter_height, filter_width, in_channels, out_channels]
  442. strides : int of list
  443. The stride of the sliding window for each dimension of input. If a single value is given it is replicated in the H and W dimension.
  444. By default the N and C dimensions are set to 1. The dimension order is determined by the value of data_format, see below for details.
  445. padding : string
  446. "SAME" or "VALID"
  447. data_format : string
  448. "NHWC", "NCHW". Defaults to "NCHW".
  449. dilations : list or ints
  450. list of ints that has length 1, 2 or 4, defaults to 1. The dilation factor for each dimension ofinput.
  451. Returns
  452. -------
  453. A Tensor. Has the same type as input.
  454. """
  455. raise NotImplementedError
  456. class Conv3D(Cell):
  457. def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_channel=None, k_size=None):
  458. super(Conv3D, self).__init__()
  459. self.data_format, self.padding = preprocess_3d_format(data_format, padding)
  460. if self.data_format is 'NDHWC':
  461. self.ms_stride = strides[1]
  462. self.ms_dilation = dilations[1]
  463. raise NotImplementedError("The optional value for data format. Currently only support “NCDHW”.")
  464. elif self.data_format is 'NCDHW':
  465. self.ms_stride = strides[2]
  466. self.ms_dilation = dilations[2]
  467. self.conv3d = P.Conv3D(
  468. out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
  469. dilation=self.ms_dilation, data_format=data_format
  470. )
  471. def construct(self, input, filters):
  472. outputs = self.conv3d(input, filters)
  473. return outputs
  474. def conv3d(input, filters, strides, padding, data_format='NDHWC', dilations=None, name=None):
  475. """
  476. Computes a 3-D convolution given 5-D input and filters tensors.
  477. Parameters
  478. ----------
  479. input : tensor
  480. Must be one of the following types: half, bfloat16, float32, float64.
  481. Shape [batch, in_depth, in_height, in_width, in_channels].
  482. filters : tensor
  483. Must have the same type as input. Shape [filter_depth, filter_height, filter_width, in_channels, out_channels].
  484. in_channels must match between input and filters.
  485. strides : list of ints
  486. A list of ints that has length >= 5. 1-D tensor of length 5.
  487. The stride of the sliding window for each dimension of input.
  488. Must have strides[0] = strides[4] = 1.
  489. padding : string
  490. A string from: "SAME", "VALID". The type of padding algorithm to use.
  491. data_format : string
  492. An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data.
  493. With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels].
  494. Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width].
  495. dilations : list of ints
  496. Defaults to [1, 1, 1, 1, 1]. 1-D tensor of length 5. The dilation factor for each dimension of input.
  497. If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension.
  498. The dimension order is determined by the value of data_format, see above for details.
  499. Dilations in the batch and depth dimensions must be 1.
  500. name : string
  501. A name for the operation (optional).
  502. Returns
  503. -------
  504. A Tensor. Has the same type as input.
  505. """
  506. raise NotImplementedError
  507. def lrn(inputs, depth_radius, bias, alpha, beta):
  508. """
  509. Local Response Normalization.
  510. Parameters
  511. ----------
  512. inputs : tensor
  513. Must be one of the following types: half, bfloat16, float32. 4-D.
  514. depth_radius : int
  515. Defaults to 5. 0-D. Half-width of the 1-D normalization window.
  516. bias : float
  517. Defaults to 1. An offset (usually positive to avoid dividing by 0).
  518. alpha : float
  519. Defaults to 1. A scale factor, usually positive.
  520. beta : float
  521. Defaults to 0.5. An exponent.
  522. Returns
  523. -------
  524. A Tensor. Has the same type as input.
  525. """
  526. pass
  527. def moments(x, axes, shift=None, keepdims=False):
  528. """
  529. Calculates the mean and variance of x.
  530. Parameters
  531. ----------
  532. x : tensor
  533. A Tensor
  534. axes : ints
  535. Axes along which to compute mean and variance.
  536. shift : int
  537. Not used in the current implementation.
  538. keepdims : bool
  539. produce moments with the same dimensionality as the input.
  540. Returns
  541. -------
  542. Two Tensor objects: mean and variance.
  543. """
  544. pass
  545. class MaxPool1d(Cell):
  546. def __init__(self, ksize, strides, padding, data_format=None):
  547. super(MaxPool1d, self).__init__()
  548. self.data_format, padding = preprocess_1d_format(data_format=data_format, padding=padding)
  549. self.expand = P.ExpandDims()
  550. _strides = (1, strides[0])
  551. _ksize = (1, ksize[0])
  552. if self.data_format == 'NWC':
  553. self.squeeze = P.Squeeze(1)
  554. _data_format = 'NHWC'
  555. if self.data_format == 'NCW':
  556. self.squeeze = P.Squeeze(2)
  557. _data_format = 'NCHW'
  558. self.max_pool = P.MaxPool(kernel_size=_ksize, strides=_strides, pad_mode=padding, data_format=_data_format)
  559. def construct(self, inputs):
  560. if self.data_format == 'NWC':
  561. x = self.expand(inputs, 1)
  562. if self.data_format == 'NCW':
  563. x = self.expand(inputs, 2)
  564. output = self.max_pool(x)
  565. output = self.squeeze(output)
  566. return output
  567. class MaxPool(Cell):
  568. def __init__(self, ksize, strides, padding, data_format=None):
  569. super(MaxPool, self).__init__()
  570. data_format, padding = preprocess_2d_format(data_format=data_format, padding=padding)
  571. if data_format == 'NHWC':
  572. _strides = (strides[1], strides[2])
  573. if data_format == 'NCHW':
  574. _strides = (strides[2], strides[3])
  575. self.maxpool = P.MaxPool(kernel_size=ksize, strides=_strides, pad_mode=padding, data_format=data_format)
  576. def construct(self, inputs):
  577. outputs = self.maxpool(inputs)
  578. return outputs
  579. def max_pool(input, ksize, strides, padding, data_format=None):
  580. """
  581. Performs the max pooling on the input.
  582. Parameters
  583. ----------
  584. input : tensor
  585. Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels] if data_format does not start
  586. with "NC" (default), or [batch_size, num_channels] + input_spatial_shape if data_format starts with "NC".
  587. Pooling happens over the spatial dimensions only.
  588. ksize : int or list of ints
  589. An int or list of ints that has length 1, N or N+2.
  590. The size of the window for each dimension of the input tensor.
  591. strides : list or list of ints
  592. An int or list of ints that has length 1, N or N+2.
  593. The stride of the sliding window for each dimension of the input tensor.
  594. padding : string
  595. 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
  596. Returns
  597. -------
  598. A Tensor of format specified by data_format. The max pooled output tensor.
  599. """
  600. data_format, padding = preprocess_2d_format(data_format=data_format, padding=padding)
  601. if data_format == 'NHWC':
  602. _strides = (strides[1], strides[2])
  603. if data_format == 'NCHW':
  604. _strides = (strides[2], strides[3])
  605. outputs = P.MaxPool(kernel_size=ksize, strides=_strides, pad_mode=padding, data_format=data_format)(input)
  606. return outputs
class AvgPool1d(Cell):
    """
    1-D average pooling built on the 2-D P.AvgPool operator.

    construct() special-cases inputs whose width is covered by a single
    window (handled with ReduceMean/Slice), and otherwise expands the
    3-D input to 4-D, runs P.AvgPool, and squeezes back.
    """

    def __init__(self, ksize, strides, padding, data_format=None):
        super(AvgPool1d, self).__init__()
        self.data_format, self.padding = preprocess_1d_format(data_format=data_format, padding=padding)
        # Dummy leading dimension of 1 so 1-D window/stride fit the 2-D op.
        self.kernel_size = (1, ksize[0])
        self.stride = (1, strides[0])
        if self.data_format == 'NWC':
            _data_format = 'NHWC'
            self.squeeze = P.Squeeze(1)
        if self.data_format == 'NCW':
            _data_format = 'NCHW'
            self.squeeze = P.Squeeze(2)
        # NOTE(review): if data_format is None, _data_format and
        # self.squeeze are never assigned and the next line raises
        # NameError — confirm callers always pass NWC/NCW.
        self.avg_pool = P.AvgPool(
            kernel_size=self.kernel_size, strides=self.stride, pad_mode=self.padding, data_format=_data_format
        )
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.slice = P.Slice()
        self.expand = P.ExpandDims()
        self.shape = P.Shape()

    def construct(self, inputs):
        x = inputs
        # NOTE(review): this unpacking assumes (batch, channel, width)
        # layout; for 'NWC' inputs channel/width would be swapped here —
        # verify against callers.
        batch, channel, width = self.shape(inputs)
        if width == self.kernel_size[1]:
            # Window spans the whole width: a plain mean over axis 2.
            x = self.reduce_mean(x, 2)
        elif width - self.kernel_size[1] < self.stride[1]:
            # Only one window fits: crop to the window, then average it.
            x = self.slice(x, (0, 0, 0), (batch, channel, self.kernel_size[1]))
            x = self.reduce_mean(x, 2)
        else:
            # General case: add a dummy axis and use the 2-D AvgPool op.
            if self.data_format == 'NCW':
                x = self.expand(x, 2)
            if self.data_format == 'NWC':
                x = self.expand(x, 1)
            x = self.avg_pool(x)
            x = self.squeeze(x)
        return x
  642. class AvgPool(Cell):
  643. def __init__(self, ksize, strides, padding, data_format=None):
  644. super(AvgPool, self).__init__()
  645. self.data_format, self.padding = preprocess_2d_format(data_format=data_format, padding=padding)
  646. ms_ksize = ksize[1]
  647. ms_strides = strides[1]
  648. self.avgpool = P.AvgPool(ksize=ms_ksize, strides=ms_strides, padding=padding, data_format=self.data_format)
  649. def construct(self, inputs):
  650. outputs = self.avgpool(inputs)
  651. return outputs
  652. def avg_pool(input, ksize, strides, padding):
  653. """
  654. Performs the avg pooling on the input.
  655. Parameters
  656. ----------
  657. input : tensor
  658. Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels]
  659. if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape
  660. if data_format starts with "NC". Pooling happens over the spatial dimensions only.
  661. ksize : int or list of ints
  662. An int or list of ints that has length 1, N or N+2.
  663. The size of the window for each dimension of the input tensor.
  664. strides : int or list of ints
  665. An int or list of ints that has length 1, N or N+2.
  666. The stride of the sliding window for each dimension of the input tensor.
  667. padding : string
  668. 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
  669. Returns
  670. -------
  671. A Tensor of format specified by data_format. The average pooled output tensor.
  672. """
  673. padding = padding_format(padding)
  674. ms_ksize = ksize[0]
  675. ms_strides = strides[1]
  676. outputs = P.AvgPool(ksize=ms_ksize, strides=ms_strides, padding=padding)
  677. return outputs(input)
  678. class MaxPool3d(Cell):
  679. def __init__(self, ksize, strides, padding, data_format=None):
  680. super(MaxPool3d, self).__init__()
  681. self.data_format, self.padding = preprocess_3d_format(data_format, padding)
  682. if data_format == 'NDHWC':
  683. _strides = (strides[1], strides[2], strides[3])
  684. if data_format == 'NCDHW':
  685. _strides = (strides[2], strides[3], strides[4])
  686. self.max_pool3d = P.MaxPool3D(
  687. kernel_size=ksize, strides=_strides, padding=padding, data_format=self.data_format
  688. )
  689. def __call__(self, inputs):
  690. outputs = self.max_pool3d(inputs)
  691. return outputs
  692. def max_pool3d(input, ksize, strides, padding, data_format=None, name=None):
  693. """
  694. Performs the max pooling on the input.
  695. Parameters
  696. ----------
  697. input : tensor
  698. A 5-D Tensor of the format specified by data_format.
  699. ksize : int or list of ints
  700. An int or list of ints that has length 1, 3 or 5.
  701. The size of the window for each dimension of the input tensor.
  702. strides : int or list of ints
  703. An int or list of ints that has length 1, 3 or 5.
  704. The stride of the sliding window for each dimension of the input tensor.
  705. padding : string
  706. 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
  707. data_format : string
  708. "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data.
  709. With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels].
  710. Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width].
  711. name : string
  712. A name for the operation (optional).
  713. Returns
  714. -------
  715. A Tensor of format specified by data_format. The max pooled output tensor.
  716. """
  717. pass
  718. class AvgPool3d(Cell):
  719. def __init__(self, ksize, strides, padding, data_format=None):
  720. super(AvgPool3d, self).__init__()
  721. self.data_format, self.padding = preprocess_3d_format(data_format, padding)
  722. if data_format == 'NDHWC':
  723. _strides = (strides[1], strides[2], strides[3])
  724. if data_format == 'NCDHW':
  725. _strides = (strides[2], strides[3], strides[4])
  726. raise NotImplementedError
  727. def __call__(self, inputs):
  728. pass
  729. def avg_pool3d(input, ksize, strides, padding, data_format=None, name=None):
  730. """
  731. Performs the average pooling on the input.
  732. Parameters
  733. ----------
  734. input : tensor
  735. A 5-D Tensor of shape [batch, height, width, channels] and type float32, float64, qint8, quint8, or qint32.
  736. ksize : int or list of ints
  737. An int or list of ints that has length 1, 3 or 5. The size of the window for each dimension of the input tensor.
  738. strides : int or list of ints
  739. An int or list of ints that has length 1, 3 or 5.
  740. The stride of the sliding window for each dimension of the input tensor.
  741. padding : string
  742. 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
  743. data_format : string
  744. 'NDHWC' and 'NCDHW' are supported.
  745. name : string
  746. Optional name for the operation.
  747. Returns
  748. -------
  749. A Tensor with the same type as value. The average pooled output tensor.
  750. """
  751. pass
  752. def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_format=None, dilations=None, name=None):
  753. """
  754. Performs an N-D pooling operation.
  755. Parameters
  756. ----------
  757. input : tensor
  758. Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels]
  759. if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape
  760. if data_format starts with "NC". Pooling happens over the spatial dimensions only.
  761. window_shape : int
  762. Sequence of N ints >= 1.
  763. pooling_type : string
  764. Specifies pooling operation, must be "AVG" or "MAX".
  765. strides : ints
  766. Sequence of N ints >= 1. Defaults to [1]*N. If any value of strides is > 1, then all values of dilation_rate must be 1.
  767. padding : string
  768. The padding algorithm, must be "SAME" or "VALID". Defaults to "SAME".
  769. See the "returns" section of tf.ops.convolution for details.
  770. data_format : string
  771. Specifies whether the channel dimension of the input and output is the last dimension (default, or if data_format does not start with "NC"),
  772. or the second dimension (if data_format starts with "NC").
  773. For N=1, the valid values are "NWC" (default) and "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW".
  774. For N=3, the valid values are "NDHWC" (default) and "NCDHW".
  775. dilations : list of ints
  776. Dilation rate. List of N ints >= 1. Defaults to [1]*N. If any value of dilation_rate is > 1, then all values of strides must be 1.
  777. name : string
  778. Optional. Name of the op.
  779. Returns
  780. -------
  781. Tensor of rank N+2, of shape [batch_size] + output_spatial_shape + [num_channels]
  782. """
  783. pass
  784. class DepthwiseConv2d(Cell):
  785. def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1):
  786. super(DepthwiseConv2d, self).__init__()
  787. self.data_format, self.padding = preprocess_2d_format(data_format, padding)
  788. self.ms_stride = strides[1]
  789. self.ms_dilation = dilations[1]
  790. self.depthwise_conv2d = P.DepthwiseConv2dNative(
  791. channel_multiplier=channel_multiplier, kernel_size=ksize, stride=self.ms_stride, dilation=self.ms_dilation
  792. )
  793. def construct(self, input, filter):
  794. if self.data_format == 'NHWC':
  795. input = nhwc_to_nchw(input)
  796. outputs = self.depthwise_conv2d(input, filter)
  797. if self.data_format == 'NHWC':
  798. outputs = nchw_to_nhwc(outputs)
  799. return outputs
  800. def depthwise_conv2d(input, filter, strides, padding, data_format=None, dilations=None, name=None):
  801. """
  802. Depthwise 2-D convolution.
  803. Parameters
  804. ----------
  805. input : tensor
  806. 4-D with shape according to data_format.
  807. filter : tensor
  808. 4-D with shape [filter_height, filter_width, in_channels, channel_multiplier].
  809. strides : list
  810. 1-D of size 4. The stride of the sliding window for each dimension of input.
  811. padding : string
  812. 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
  813. data_format : string
  814. The data format for input. Either "NHWC" (default) or "NCHW".
  815. dilations : list
  816. 1-D of size 2. The dilation rate in which we sample input values across the height and width dimensions in atrous convolution.
  817. If it is greater than 1, then all values of strides must be 1.
  818. name : string
  819. A name for this operation (optional).
  820. Returns
  821. -------
  822. A 4-D Tensor with shape according to data_format.
  823. E.g., for "NHWC" format, shape is [batch, out_height, out_width, in_channels * channel_multiplier].
  824. """
  825. pass
class Conv1d_transpose(Cell):
    """1-D transposed convolution built on `P.Conv2DBackpropInput`.

    The 3-D input and filters are lifted to 4-D with a dummy height axis,
    run through the 2-D deconvolution primitive, then squeezed back to 3-D.
    """

    def __init__(self, stride, padding, data_format, dilations=None, out_channel=None, k_size=None, in_channels=None):
        super(Conv1d_transpose, self).__init__()
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)
        self.in_channels = in_channels
        self.out_channel = out_channel
        # Lift the scalar hyper-parameters to 2-D with a unit height entry.
        self.stride = (1, stride)
        self.dilations = (1, dilations)
        self.k_size = (1, k_size)
        # Switch to the corresponding 2-D layout; `h_axis` is the index of
        # the dummy height axis that gets inserted and later squeezed.
        if self.data_format == 'NWC':
            self.data_format = 'NHWC'
            self.h_axis = 1
        else:
            self.data_format = 'NCHW'
            self.h_axis = 2
        self.conv2d_transpose = P.Conv2DBackpropInput(
            out_channel=self.in_channels, kernel_size=self.k_size, pad_mode=self.padding, stride=self.stride,
            dilation=self.dilations, mode=1, group=1, data_format=self.data_format
        )
        self.shape = P.Shape()
        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(self.h_axis)

    def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
        """Spatial output length of a transposed conv under 'same'/'valid' padding."""
        length = 0
        # Effective filter extent once dilation is taken into account.
        filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
        if self.padding == 'same':
            length = input_length * stride_size
        elif self.padding == 'valid':
            length = input_length * stride_size + max(filter_size - stride_size, 0)
        return length

    def construct(self, x, filters):
        # 3-D -> 4-D: insert the dummy height axis on data and filters.
        x = self.expand_dims(x, self.h_axis)
        filters = self.expand_dims(filters, self.h_axis)
        if self.data_format == 'NCHW':
            n, _, h, w = self.shape(x)
        else:
            n, h, w, _ = self.shape(x)
        # `h` is the unit dummy axis; `w` is the real temporal length.
        h_out = self._deconv_output_length(h, self.k_size[0], self.stride[0], self.dilations[0])
        w_out = self._deconv_output_length(w, self.k_size[1], self.stride[1], self.dilations[1])
        if self.data_format == 'NCHW':
            output_size = (n, self.out_channel, h_out, w_out)
        else:
            output_size = (n, h_out, w_out, self.out_channel)
        output = self.conv2d_transpose(x, filters, output_size)
        # Drop the dummy axis to return a 3-D tensor.
        output = self.squeeze(output)
        return output
  872. def conv1d_transpose(
  873. input, filters, output_shape, strides, padding='SAME', data_format='NWC', dilations=None, name=None
  874. ):
  875. """
  876. The transpose of conv1d.
  877. Parameters
  878. ----------
  879. input : tensor
  880. A 3-D Tensor of type float and shape [batch, in_width, in_channels]
  881. for NWC data format or [batch, in_channels, in_width] for NCW data format.
  882. filters : tensor
  883. A 3-D Tensor with the same type as value and shape [filter_width, output_channels, in_channels].
  884. filter's in_channels dimension must match that of value.
  885. output_shape : tensor
  886. A 1-D Tensor, containing three elements, representing the output shape of the deconvolution op.
  887. strides : list
  888. An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step.
  889. padding : string
  890. 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
  891. data_format : string
  892. 'NWC' and 'NCW' are supported.
  893. dilations : list
  894. An int or list of ints that has length 1 or 3 which defaults to 1.
  895. The dilation factor for each dimension of input. If set to k > 1,
  896. there will be k-1 skipped cells between each filter element on that dimension.
  897. Dilations in the batch and depth dimensions must be 1.
  898. name : string
  899. Optional name for the returned tensor.
  900. Returns
  901. -------
  902. A Tensor with the same type as value.
  903. """
  904. pass
  905. class Conv2d_transpose(Cell):
  906. def __init__(self, strides, padding, data_format, dilations=None, out_channel=None, k_size=None, in_channels=None):
  907. super(Conv2d_transpose, self).__init__()
  908. self.data_format, self.padding = preprocess_2d_format(data_format, padding)
  909. self.in_channels = in_channels
  910. self.out_channel = out_channel
  911. self.k_size = k_size
  912. self.strides = strides
  913. self.dilations = dilations
  914. self.conv2d_transpose = P.Conv2DBackpropInput(
  915. out_channel=self.in_channels, kernel_size=self.k_size, pad_mode=self.padding, stride=self.strides,
  916. dilation=self.dilations, mode=1, group=1, data_format=self.data_format
  917. )
  918. self.shape = P.Shape()
  919. def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
  920. length = 0
  921. filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
  922. if self.padding == 'same':
  923. length = input_length * stride_size
  924. elif self.padding == 'valid':
  925. length = input_length * stride_size + max(filter_size - stride_size, 0)
  926. return length
  927. def construct(self, x, filters):
  928. if self.data_format == 'NHWC':
  929. h_axis, w_axis = 1, 2
  930. n, h, w, _ = self.shape(x)
  931. else:
  932. h_axis, w_axis = 2, 3
  933. n, _, h, w = self.shape(x)
  934. if isinstance(self.strides, int):
  935. strides_h = self.strides
  936. strides_w = self.strides
  937. else:
  938. strides_list = list(self.strides)
  939. if len(strides_list) == 2:
  940. strides_h = strides_list[0]
  941. strides_w = strides_list[1]
  942. elif len(strides_list) == 4:
  943. strides_h = strides_list[h_axis]
  944. strides_w = strides_list[w_axis]
  945. if self.dilations is not None:
  946. if isinstance(self.dilations, int):
  947. dilations_h = self.dilations
  948. dilations_w = self.dilations
  949. else:
  950. dilations_list = list(self.dilations)
  951. if len(dilations_list) == 2:
  952. dilations_h = dilations_list[0]
  953. dilations_w = dilations_list[1]
  954. elif len(dilations_list) == 4:
  955. dilations_h = dilations_list[h_axis]
  956. dilations_w = dilations_list[w_axis]
  957. h_out = self._deconv_output_length(h, self.k_size[0], strides_h, dilations_h)
  958. w_out = self._deconv_output_length(w, self.k_size[1], strides_w, dilations_w)
  959. if self.data_format == 'NCHW':
  960. output_size = (n, self.out_channel, h_out, w_out)
  961. else:
  962. output_size = (n, h_out, w_out, self.out_channel)
  963. output = self.conv2d_transpose(x, filters, output_size)
  964. return output
  965. def conv2d_transpose(
  966. input, filters, output_shape, strides, padding='SAME', data_format='NHWC', dilations=None, name=None
  967. ):
  968. """
  969. The transpose of conv2d.
  970. Parameters
  971. ----------
  972. input : tensor
  973. A 4-D Tensor of type float and shape [batch, height, width, in_channels]
  974. for NHWC data format or [batch, in_channels, height, width] for NCHW data format.
  975. filters : tensor
  976. A 4-D Tensor with the same type as input and shape [height, width,
  977. output_channels, in_channels]. filter's in_channels dimension must match that of input.
  978. output_shape : tensor
  979. A 1-D Tensor representing the output shape of the deconvolution op.
  980. strides : list
  981. An int or list of ints that has length 1, 2 or 4. The stride of the sliding window for each dimension of input.
  982. If a single value is given it is replicated in the H and W dimension.
  983. By default the N and C dimensions are set to 0.
  984. The dimension order is determined by the value of data_format, see below for details.
  985. padding : string
  986. 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
  987. data_format : string
  988. 'NHWC' and 'NCHW' are supported.
  989. dilations : list
  990. An int or list of ints that has length 1, 2 or 4, defaults to 1.
  991. name : string
  992. Optional name for the returned tensor.
  993. Returns
  994. -------
  995. A Tensor with the same type as input.
  996. """
  997. pass
  998. class Conv3d_transpose(Cell):
  999. def __init__(
  1000. self, strides, padding, data_format='NDHWC', dilations=None, name=None, out_channel=None, k_size=None,
  1001. in_channels=None
  1002. ):
  1003. super(Conv3d_transpose, self).__init__()
  1004. self.data_format, self.padding = preprocess_3d_format(data_format, padding)
  1005. self.conv3d_transpose = P.Conv3DTranspose(
  1006. in_channel=in_channels, out_channel=out_channel, kernel_size=k_size, mode=1, pad_mode=self.padding,
  1007. stride=strides, dilation=dilations, data_format=self.data_format
  1008. )
  1009. def construct(self, input, filters):
  1010. output = self.conv3d_transpose(input, filters)
  1011. return output
  1012. def conv3d_transpose(
  1013. input, filters, output_shape, strides, padding='SAME', data_format='NDHWC', dilations=None, name=None
  1014. ):
  1015. """
  1016. The transpose of conv3d.
  1017. Parameters
  1018. ----------
  1019. input : tensor
  1020. A 5-D Tensor of type float and shape [batch, height, width, in_channels] for
  1021. NHWC data format or [batch, in_channels, height, width] for NCHW data format.
  1022. filters : tensor
  1023. A 5-D Tensor with the same type as value and shape [height, width, output_channels, in_channels].
  1024. filter's in_channels dimension must match that of value.
  1025. output_shape : tensor
  1026. A 1-D Tensor representing the output shape of the deconvolution op.
  1027. strides : list
  1028. An int or list of ints that has length 1, 3 or 5.
  1029. padding : string
  1030. 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
  1031. data_format : string
  1032. 'NDHWC' and 'NCDHW' are supported.
  1033. dilations : list of ints
  1034. An int or list of ints that has length 1, 3 or 5, defaults to 1.
  1035. name : string
  1036. Optional name for the returned tensor.
  1037. Returns
  1038. -------
  1039. A Tensor with the same type as value.
  1040. """
  1041. pass
class BatchNorm(Cell):
    """Batch Normalization base class.

    Wraps MindSpore's BatchNorm / SyncBatchNorm primitives.  Depending on
    `device_num_each_group` and `process_groups`, statistics are either
    computed per device or synchronized across a communication group.

    Parameters
    ----------
    num_features : int
        Number of channels (C) of the input; must be >= 1.
    epsilon : float
        Small constant added to the variance for numerical stability.
    decay : float
        Moving-average decay in [0, 1]; MindSpore's `momentum` is `1 - decay`.
    gamma, beta
        Scale and shift parameters passed straight to the BN kernels.
    moving_mean, moving_var
        Running statistics used by the inference kernel.
    is_train : bool
        Selects the training (`bn_train`) or inference (`bn_infer`) path
        in `construct`.
    device_num_each_group : int
        Values > 1 enable global (cross-device) batch norm.
    process_groups : int or list
        Explicit rank groups for SyncBatchNorm; 0 disables, None means
        "all ranks" when more than one device is present.
    data_format : str
        'NCHW' or 'NHWC' (NHWC is only supported on GPU).
    """

    @cell_attr_register
    def __init__(
        self, num_features, epsilon=1e-5, decay=0.9, gamma=None, beta=None, moving_mean=None, moving_var=None,
        is_train=None, device_num_each_group=1, process_groups=0, data_format='NCHW'
    ):
        super(BatchNorm, self).__init__()
        # Accept the common aliases for the two supported layouts.
        if data_format in ["channels_last", "NHWC", "nhwc"]:
            data_format = "NHWC"
        elif data_format in ["channels_first", "NCHW", "nchw"]:
            data_format = "NCHW"
        validator.check_value_type('num_features', num_features, [int], self.cls_name)
        if num_features < 1:
            raise ValueError("num_features must be at least 1")
        if decay < 0 or decay > 1:
            raise ValueError("momentum should be a number in range [0, 1], but got {}".format(decay))
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format only support in GPU target.")
        self.use_batch_statistics = is_train
        self.num_features = num_features
        self.eps = epsilon
        self.moving_mean = moving_mean
        self.moving_variance = moving_var
        self.gamma = gamma
        self.beta = beta
        self.group_device_num = validator.check_positive_int(device_num_each_group)
        self.process_groups = process_groups
        self.is_global = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        global SYNC_BN_GROUP_NAME
        # for GlobalBatchNorm: split all ranks into fixed-size groups and
        # create/join the communication group containing this rank.
        if self.group_device_num != 1:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            self.device_list = [i for i in range(0, self.rank_size)]
            self.rank_list = self.list_group(self.device_list, self.group_device_num)
            self.rank_list_idx = len(self.rank_list)
            for i in range(self.rank_list_idx):
                if self.rank_id in self.rank_list[i]:
                    self.is_global = True
                    # The group is created at most once per process (guarded
                    # by the module-global SYNC_BN_GROUP_NAME).
                    if SYNC_BN_GROUP_NAME == "":
                        SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
                        management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
        # for SyncBatchNorm: explicit rank groups, or every rank when
        # process_groups is None and multiple devices exist.
        if self.process_groups != 0:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            if self.process_groups is not None:
                validator.check_isinstance("process_groups", self.process_groups, list)
                self._check_rank_ids(self.process_groups, self.rank_size)
                for i in range(len(self.process_groups)):
                    validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list)
                    self.group_device_num = len(self.process_groups[i])
                    if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
                        self.is_global = True
                        if SYNC_BN_GROUP_NAME == "":
                            SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
                            management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
            elif self.rank_size > 1:
                self.is_global = True
                self.group_device_num = self.rank_size
                self.device_list = [i for i in range(0, self.rank_size)]
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "sync_bn_group0"
                    management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.cast = P.Cast()
        self.dtype = P.DType()
        self.reshape = P.Reshape()
        self._target = context.get_context("device_target")
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        # MindSpore's momentum is the complement of the TF-style decay.
        self.momentum = 1.0 - decay
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False
        self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps, momentum=self.momentum, data_format=self.format)
        if self.is_global:
            # Cross-device statistics for global/sync batch norm replace the
            # per-device training kernel.
            self.bn_train = inner.SyncBatchNorm(
                epsilon=self.eps, momentum=self.momentum, group=SYNC_BN_GROUP_NAME, device_num=self.group_device_num
            )
        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
        # Shard strategies for the moving-statistics update under data parallel.
        data_parallel_strategy = ((1, ), (1, ))
        data_parallel_strategy_one = ((1, ), ())
        self.sub_mean = P.Sub().shard(data_parallel_strategy)
        self.sub_var = P.Sub().shard(data_parallel_strategy)
        self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
        self.mul_var = P.Mul().shard(data_parallel_strategy_one)
        self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
        self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)

    def list_group(self, world_rank, group_size):
        """Partition `world_rank` into consecutive groups of `group_size` ranks."""
        if group_size > get_group_size():
            raise ValueError(
                "group size can not be greater than local rank size, group size is {}, "
                "local_rank_size is {}".format(group_size, get_group_size())
            )
        if len(world_rank) % group_size != 0:
            raise ValueError("please make your group size correct.")
        # zip over the same iterator `group_size` times chunks the rank list.
        world_rank_list = zip(*(iter(world_rank), ) * group_size)
        group_list = [list(i) for i in world_rank_list]
        return group_list

    def _check_rank_ids(self, process_groups, rank_size):
        """Validate that every rank id is in [0, rank_size) and unique."""
        seen = set()
        for rid in itertools.chain(*process_groups):
            validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups")
            if rid in seen:
                raise ValueError("rank id in process_groups should not be duplicated.")
            seen.add(rid)

    def construct(self, inputs):
        # 5-D inputs are folded to 4-D (merging two spatial dims) so the 4-D
        # BN kernels apply; the original shape is restored afterwards.
        x_shape = F.shape(inputs)
        if len(x_shape) == 5:
            inputs = self.reshape(inputs, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
        flag = self.use_batch_statistics
        if flag:
            # Training path: batch statistics (possibly synchronized).
            output = self.bn_train(inputs, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0]
            if len(x_shape) == 5:
                output = self.reshape(output, x_shape)
            return output
        # Inference path: moving statistics.
        output = self.bn_infer(inputs, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0]
        if len(x_shape) == 5:
            output = self.reshape(output, x_shape)
        return output

    def extend_repr(self):
        """Extra repr listing the layer's hyper-parameters and statistics."""
        return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
            self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance
        )
  1173. class GroupConv2D(Cell):
  1174. def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, groups):
  1175. super(GroupConv2D, self).__init__()
  1176. self.data_format, self.padding = preprocess_2d_format(data_format, padding)
  1177. if self.data_format is 'NHWC':
  1178. self.ms_stride = strides[1]
  1179. self.ms_dilation = dilations[1]
  1180. elif self.data_format is 'NCHW':
  1181. self.ms_stride = strides[2]
  1182. self.ms_dilation = dilations[2]
  1183. self.conv2d = P.Conv2D(
  1184. out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
  1185. dilation=self.ms_dilation, mode=1, group=groups, data_format=self.data_format
  1186. )
  1187. def construct(self, inputs, filters):
  1188. outputs = self.conv2d(inputs, filters)
  1189. return outputs
class SeparableConv1D(Cell):
    """Depthwise-separable 1-D convolution (depthwise conv followed by a
    1x1 pointwise conv), implemented with 2-D `P.Conv2D` primitives.

    'NWC' inputs are converted to channels-first, processed in the
    primitive's default channels-first layout, and converted back.
    """

    def __init__(self, stride, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
        super(SeparableConv1D, self).__init__()
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)
        # Lift 1-D hyper-parameters to 2-D with a unit height component.
        self.stride = (1, stride)
        self.dilations = (1, dilations)
        self.k_size = (1, k_size)
        self.out_channel = out_channel
        self.in_channel = in_channel
        self.depth_multiplier = depth_multiplier
        # group == in_channel makes this a depthwise convolution.
        self.depthwise_conv = P.Conv2D(
            out_channel=self.in_channel * self.depth_multiplier, kernel_size=self.k_size, pad_mode=self.padding,
            stride=self.stride, dilation=self.dilations, mode=1, group=self.in_channel
        )
        # 1x1 convolution mixing depthwise outputs into `out_channel` maps.
        self.pointwise_conv = P.Conv2D(
            out_channel=self.out_channel, kernel_size=(1, 1), pad_mode=self.padding, stride=(1, 1), dilation=(1, 1),
            mode=1, group=1
        )
        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(2)

    def construct(self, x, depthwise_filters, pointwise_filters):
        if self.data_format == 'NWC':
            # Channels-last input: convert to channels-first for P.Conv2D.
            x = nhwc_to_nchw(x)
        # Insert the dummy height axis on the data and both filter banks.
        x = self.expand_dims(x, 2)
        depthwise_filters = self.expand_dims(depthwise_filters, 2)
        pointwise_filters = self.expand_dims(pointwise_filters, 2)
        outputs = self.depthwise_conv(x, depthwise_filters)
        outputs = self.pointwise_conv(outputs, pointwise_filters)
        # Drop the dummy axis, then restore the caller's layout if needed.
        outputs = self.squeeze(outputs)
        if self.data_format == 'NWC':
            outputs = nchw_to_nhwc(outputs)
        return outputs
  1222. class SeparableConv2D(Cell):
  1223. def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
  1224. super(SeparableConv2D, self).__init__()
  1225. self.data_format, self.padding = preprocess_2d_format(data_format, padding)
  1226. self.k_size = k_size
  1227. self.out_channel = out_channel
  1228. self.in_channel = in_channel
  1229. self.depth_multiplier = depth_multiplier
  1230. if self.data_format is 'NHWC':
  1231. self.ms_stride = strides[1]
  1232. self.ms_dilation = dilations[1]
  1233. elif self.data_format is 'NCHW':
  1234. self.ms_stride = strides[2]
  1235. self.ms_dilation = dilations[2]
  1236. self.depthwise_conv = P.Conv2D(
  1237. out_channel=self.in_channel * self.depth_multiplier, kernel_size=self.k_size, pad_mode=self.padding,
  1238. stride=self.ms_stride, dilation=self.ms_dilation, mode=1, group=self.in_channel,
  1239. data_format=self.data_format
  1240. )
  1241. self.pointwise_conv = P.Conv2D(
  1242. out_channel=self.out_channel, kernel_size=(1, 1), pad_mode=self.padding, stride=(1, 1), dilation=(1, 1),
  1243. mode=1, group=1, data_format=self.data_format
  1244. )
  1245. def construct(self, x, depthwise_filters, pointwise_filters):
  1246. outputs = self.depthwise_conv(x, depthwise_filters)
  1247. outputs = self.pointwise_conv(outputs, pointwise_filters)
  1248. return outputs
  1249. class AdaptiveMeanPool1D(Cell):
  1250. def __init__(self, output_size, data_format):
  1251. super(AdaptiveMeanPool1D, self).__init__()
  1252. self.data_format, _ = preprocess_1d_format(data_format, None)
  1253. self.output_size = output_size
  1254. if self.data_format == 'NWC':
  1255. self.data_format = 'NHWC'
  1256. self.h_axis = 1
  1257. else:
  1258. self.data_format = 'NCHW'
  1259. self.h_axis = 2
  1260. self.expand_dims = P.ExpandDims()
  1261. self.squeeze = P.Squeeze(self.h_axis)
  1262. self.shape = P.Shape()
  1263. def construct(self, inputs):
  1264. if self.data_format == 'NHWC':
  1265. n, w, c = self.shape(inputs)
  1266. else:
  1267. n, c, w = self.shape(inputs)
  1268. inputs = self.expand_dims(inputs, self.h_axis)
  1269. stride = (1, w // self.output_size)
  1270. kernel = (1, w - (self.output_size - 1) * stride[1])
  1271. outputs = P.AvgPool(kernel_size=kernel, strides=stride, pad_mode='VALID', data_format=self.data_format)(inputs)
  1272. outputs = self.squeeze(outputs)
  1273. return outputs
  1274. class AdaptiveMeanPool2D(Cell):
  1275. def __init__(self, output_size, data_format):
  1276. super(AdaptiveMeanPool2D, self).__init__()
  1277. self.data_format, _ = preprocess_2d_format(data_format, None)
  1278. self.output_size = output_size
  1279. if self.data_format == 'NHWC':
  1280. self.h_axis = 1
  1281. else:
  1282. self.h_axis = 2
  1283. self.shape = P.Shape()
  1284. def construct(self, inputs):
  1285. if self.data_format == 'NHWC':
  1286. n, h, w, c = self.shape(inputs)
  1287. else:
  1288. n, c, h, w = self.shape(inputs)
  1289. out_h, out_w = self.output_size
  1290. stride_h = h // out_h
  1291. kernel_h = h - (out_h - 1) * stride_h
  1292. stride_w = w // out_w
  1293. kernel_w = w - (out_w - 1) * stride_w
  1294. outputs = P.AvgPool(
  1295. kernel_size=(kernel_h, kernel_w), strides=(stride_h, stride_w), pad_mode='VALID',
  1296. data_format=self.data_format
  1297. )(inputs)
  1298. return outputs
  1299. class AdaptiveMeanPool3D(Cell):
  1300. def __init__(self, output_size, data_format):
  1301. pass
  1302. def __call__(self, inputs):
  1303. raise NotImplementedError
  1304. class AdaptiveMaxPool1D(Cell):
  1305. def __init__(self, output_size, data_format):
  1306. super(AdaptiveMaxPool1D, self).__init__()
  1307. self.data_format, _ = preprocess_1d_format(data_format, None)
  1308. self.output_size = output_size
  1309. if self.data_format == 'NWC':
  1310. self.data_format = 'NHWC'
  1311. self.h_axis = 1
  1312. else:
  1313. self.data_format = 'NCHW'
  1314. self.h_axis = 2
  1315. self.expand_dims = P.ExpandDims()
  1316. self.squeeze = P.Squeeze(self.h_axis)
  1317. self.shape = P.Shape()
  1318. def construct(self, inputs):
  1319. if self.data_format == 'NHWC':
  1320. n, w, c = self.shape(inputs)
  1321. else:
  1322. n, c, w = self.shape(inputs)
  1323. inputs = self.expand_dims(inputs, self.h_axis)
  1324. stride = (1, w // self.output_size)
  1325. kernel = (1, w - (self.output_size - 1) * stride[1])
  1326. outputs = P.MaxPool(kernel_size=kernel, strides=stride, pad_mode='VALID', data_format=self.data_format)(inputs)
  1327. outputs = self.squeeze(outputs)
  1328. return outputs
  1329. class AdaptiveMaxPool2D(Cell):
  1330. def __init__(self, output_size, data_format):
  1331. super(AdaptiveMaxPool2D, self).__init__()
  1332. self.data_format, _ = preprocess_2d_format(data_format, None)
  1333. self.output_size = output_size
  1334. if self.data_format == 'NHWC':
  1335. self.h_axis = 1
  1336. else:
  1337. self.h_axis = 2
  1338. self.shape = P.Shape()
  1339. def construct(self, inputs):
  1340. if self.data_format == 'NHWC':
  1341. n, h, w, c = self.shape(inputs)
  1342. else:
  1343. n, c, h, w = self.shape(inputs)
  1344. out_h, out_w = self.output_size
  1345. stride_h = h // out_h
  1346. kernel_h = h - (out_h - 1) * stride_h
  1347. stride_w = w // out_w
  1348. kernel_w = w - (out_w - 1) * stride_w
  1349. outputs = P.MaxPool(
  1350. kernel_size=(kernel_h, kernel_w), strides=(stride_h, stride_w), pad_mode='VALID',
  1351. data_format=self.data_format
  1352. )(inputs)
  1353. return outputs
  1354. class AdaptiveMaxPool3D(Cell):
  1355. def __init__(self, output_size, data_format):
  1356. pass
  1357. def __call__(self, inputs):
  1358. raise NotImplementedError
  1359. class BinaryConv2D(Cell):
  1360. def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel):
  1361. super(BinaryConv2D, self).__init__()
  1362. self.data_format, self.padding = preprocess_2d_format(data_format, padding)
  1363. if self.data_format is 'NHWC':
  1364. self.ms_stride = strides[1]
  1365. self.ms_dilation = dilations[1]
  1366. elif self.data_format is 'NCHW':
  1367. self.ms_stride = strides[2]
  1368. self.ms_dilation = dilations[2]
  1369. self.conv2d = P.Conv2D(
  1370. out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
  1371. dilation=self.ms_dilation, mode=1, group=1, data_format=self.data_format
  1372. )
  1373. @bprop_getters.register(P.Sign)
  1374. def get_bprop_Sign(self):
  1375. def bprop(x, out, dout):
  1376. grad = P.clip_by_value(dout, -1, 1)
  1377. return (grad, )
  1378. return bprop
  1379. self.sign = P.Sign()
  1380. def construct(self, inputs, filters):
  1381. filters = self.sign(filters)
  1382. outputs = self.conv2d(inputs, filters)
  1383. return outputs
  1384. class DorefaConv2D(Cell):
  1385. def __init__(self, bitW, bitA, strides, padding, data_format, dilations, out_channel, k_size, in_channel):
  1386. super(DorefaConv2D, self).__init__()
  1387. self.data_format, self.padding = preprocess_2d_format(data_format, padding)
  1388. self.bitW = ms.Tensor(bitW)
  1389. self.bitA = ms.Tensor(bitA)
  1390. if self.data_format is 'NHWC':
  1391. self.ms_stride = strides[1]
  1392. self.ms_dilation = dilations[1]
  1393. # self.transpose = P.Transpose()
  1394. elif self.data_format is 'NCHW':
  1395. self.ms_stride = strides[2]
  1396. self.ms_dilation = dilations[2]
  1397. self.conv2d = P.Conv2D(
  1398. out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
  1399. dilation=self.ms_dilation, mode=1, group=1
  1400. )
  1401. @bprop_getters.register(P.Round)
  1402. def get_bprop_Round(self):
  1403. def bprop(x, out, dout):
  1404. return (dout, )
  1405. return bprop
  1406. @bprop_getters.register(P.Sign)
  1407. def get_bprop_Sign(self):
  1408. def bprop(x, out, dout):
  1409. return (dout, )
  1410. return bprop
  1411. self.mimimum = P.Minimum()
  1412. self.abs = P.Abs()
  1413. self.round = P.Round()
  1414. self.reducemean = P.ReduceMean()
  1415. self.sign = P.Sign()
  1416. self.pow = P.Pow()
  1417. self.sub = P.Sub()
  1418. self.oneslike = P.OnesLike()
  1419. def cabs(self, inputs):
  1420. a = P.stop_gradient(self.oneslike(inputs))
  1421. return self.mimimum(self.abs(inputs), a)
  1422. def _quantize_dorefa(self, x, k):
  1423. n = self.sub(self.pow(2.0, k), 1)
  1424. return self.round(x * n) / n
  1425. def quantize_active(self, x, bitA):
  1426. if bitA == 32:
  1427. return x
  1428. return self._quantize_dorefa(x, bitA)
  1429. def quantize_weight(self, x, bitW, force_quantization=False):
  1430. if bitW == 32 and not force_quantization:
  1431. return x
  1432. if bitW == 1:
  1433. E = P.stop_gradient(self.reducemean(self.abs(x)))
  1434. return self.sign(x / E) * E
  1435. x = P.clip_by_value(x * 0.5 + 0.5, 0.0, 1.0)
  1436. return 2 * self._quantize_dorefa(x, bitW) - 1
  1437. def construct(self, inputs, filters):
  1438. if self.data_format == 'NHWC':
  1439. inputs = nhwc_to_nchw(inputs)
  1440. inputs = self.quantize_active(self.cabs(inputs), self.bitA)
  1441. filters = self.quantize_weight(filters, self.bitW)
  1442. outputs = self.conv2d(inputs, filters)
  1443. if self.data_format == 'NHWC':
  1444. outputs = nchw_to_nhwc(outputs)
  1445. return outputs
  1446. class rnncell(Cell):
  1447. def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act):
  1448. super(rnncell, self).__init__()
  1449. self.weight_ih = weight_ih
  1450. self.weight_hh = weight_hh
  1451. self.bias_ih = bias_ih
  1452. self.bias_hh = bias_hh
  1453. self.act_fn = P.ReLU() if act == 'relu' else P.Tanh()
  1454. self.transpose = P.Transpose()
  1455. def construct(self, input, h):
  1456. self.weight_ih = self.transpose(self.weight_ih, (1, 0))
  1457. i2h = P.matmul(input, self.weight_ih)
  1458. if self.bias_ih is not None:
  1459. i2h += self.bias_ih
  1460. self.weight_hh = self.transpose(self.weight_hh, (1, 0))
  1461. h2h = P.matmul(h, self.weight_hh)
  1462. if self.bias_hh is not None:
  1463. h2h += self.bias_hh
  1464. h = self.act_fn(i2h + h2h)
  1465. return h, h
  1466. class lstmcell(Cell):
  1467. def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
  1468. super(lstmcell, self).__init__()
  1469. self.weight_ih = weight_ih
  1470. self.weight_hh = weight_hh
  1471. self.bias_ih = bias_ih
  1472. self.bias_hh = bias_hh
  1473. self.gate_act_fn = P.Sigmoid()
  1474. self.act_fn = P.Tanh()
  1475. self.transpose = P.Transpose()
  1476. self.split = P.Split(axis=-1, output_num=4)
  1477. def construct(self, input, h, c):
  1478. self.weight_ih = self.transpose(self.weight_ih, (1, 0))
  1479. gates = P.matmul(input, self.weight_ih)
  1480. if self.bias_ih is not None:
  1481. gates += self.bias_ih
  1482. self.weight_hh = self.transpose(self.weight_hh, (1, 0))
  1483. gates += P.matmul(h, self.weight_hh)
  1484. if self.bias_hh is not None:
  1485. gates += self.bias_hh
  1486. gate_slices = self.split(gates)
  1487. i = self.gate_act_fn(gate_slices[0])
  1488. f = self.gate_act_fn(gate_slices[1])
  1489. o = self.gate_act_fn(gate_slices[3])
  1490. c = f * c + i * self.act_fn(gate_slices[2])
  1491. h = o * self.act_fn(c)
  1492. return h, h, c
  1493. class grucell(Cell):
  1494. def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
  1495. super(grucell, self).__init__()
  1496. self.weight_ih = weight_ih
  1497. self.weight_hh = weight_hh
  1498. self.bias_ih = bias_ih
  1499. self.bias_hh = bias_hh
  1500. self.gate_act_fn = P.Sigmoid()
  1501. self.act_fn = P.Tanh()
  1502. self.transpose = P.Transpose()
  1503. self.split = P.Split(axis=-1, output_num=3)
  1504. def construct(self, input, h):
  1505. self.weight_ih = self.transpose(self.weight_ih, (1, 0))
  1506. x_gates = P.matmul(input, self.weight_ih)
  1507. if self.bias_ih is not None:
  1508. x_gates += self.bias_ih
  1509. self.weight_hh = self.transpose(self.weight_hh, (1, 0))
  1510. h_gates = P.matmul(h, self.weight_hh)
  1511. if self.bias_hh is not None:
  1512. h_gates += self.bias_hh
  1513. x_r, x_z, x_c = self.split(x_gates)
  1514. h_r, h_z, h_c = self.split(h_gates)
  1515. r = self.gate_act_fn(x_r + h_r)
  1516. z = self.gate_act_fn(x_r + h_z)
  1517. c = self.act_fn(x_c + r * h_c)
  1518. h = (h - c) * z + c
  1519. return h, h
  1520. class rnnbase(Cell):
  1521. def __init__(
  1522. self,
  1523. mode,
  1524. input_size,
  1525. hidden_size,
  1526. num_layers,
  1527. bias,
  1528. batch_first,
  1529. dropout,
  1530. bidirectional,
  1531. is_train,
  1532. ):
  1533. super(rnnbase, self).__init__()
  1534. self.mode = mode
  1535. self.input_size = input_size
  1536. self.hidden_size = hidden_size
  1537. self.num_layers = num_layers
  1538. self.bidirect = 2 if bidirectional else 1
  1539. self.batch_first = batch_first
  1540. if mode == 'LSTM':
  1541. self.lstm = ms.nn.LSTM(
  1542. input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias,
  1543. batch_first=batch_first, dropout=dropout, bidirectional=bidirectional
  1544. )
  1545. elif mode == 'GRU':
  1546. raise NotImplementedError
  1547. elif mode == 'RNN_TANH':
  1548. raise NotImplementedError
  1549. elif mode == 'RNN_RELU':
  1550. raise NotImplementedError
  1551. self.zeros = P.Zeros()
  1552. def construct(self, input, states):
  1553. input_shape = input.shape
  1554. input_dtype = input.dtype
  1555. if self.mode == 'LSTM':
  1556. if self.batch_first:
  1557. batch_size = input_shape[0]
  1558. else:
  1559. batch_size = input_shape[1]
  1560. if states is None:
  1561. h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
  1562. c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
  1563. states = (h, c)
  1564. output, (h, c) = self.lstm(input, states)
  1565. return output, (h, c)

TensorLayer3.0 是一款支持多种深度学习框架作为计算后端的深度学习库,计划兼容 TensorFlow、PyTorch、MindSpore、Paddle。