You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

thor_layer.py 33 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """layers for second order optimization"""
import numpy as np

import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore._checkparam import Validator, Rel, twice
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import operations as P
  26. __all__ = ['Dense_Thor', 'Conv2d_Thor', 'Embedding_Thor']
class Dense_Thor(Cell):
    r"""
    The dense connected layer with THOR statistic collection.

    Applies dense connected layer for the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

    where :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
    data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
    with the same data type as the inputs created by the layer (only if has_bias is True).

    In addition to the plain dense computation, while ``self.thor`` is True the layer
    accumulates two non-trainable statistics for the THOR second-order optimizer:
    ``matrix_A`` (x^T x / batch, computed in ``construct``) and ``matrix_G``
    (dout^T dout / batch, captured on the backward pass via ``P.InsertGradientOf``).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str): activate function applied to the output of the fully connected layer, eg. 'ReLU'.
            Default: None.

    Raises:
        ValueError: If weight_init or bias_init shape is incorrect.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net = nn.Dense(3, 4)
        >>> net(input)
        [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
         [ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None):
        super(Dense_Thor, self).__init__()
        # When True, construct() also refreshes matrix_A and hooks matrix_G capture.
        self.thor = True
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = Validator.check_bool(has_bias)
        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("Weight init shape error.")
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
        self.bias = None
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("Bias init shape error.")
            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
            self.bias_add = P.BiasAdd()
        self.matmul = P.MatMul(transpose_b=True)
        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None
        # Input covariance statistic; non-trainable, read by the THOR optimizer.
        self.matrix_A = Parameter(Tensor(np.zeros([in_channels, in_channels]).astype(np.float32)),
                                  name='matrix_A', requires_grad=False)
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.mul = P.Mul()
        self.is_Ascend = True
        if context.get_context("device_target") == "Ascend":
            if out_channels == 1001:
                # 1001 outputs are padded by 23 to 1024 (a multiple of 128) —
                # presumably to match the block size of the Cus* Ascend kernels;
                # TODO(review): confirm against the custom-op requirements.
                self.matrix_G = Parameter(Tensor(np.zeros([1024, 1024]).astype(np.float32)),
                                          name='matrix_G', requires_grad=False)
                self.pad = P.Pad(((0, 23), (0, 23)))
                self.pad1 = P.Pad(((0, 7), (0, 7)))
                self.slice = P.Slice()
                self.add = P.TensorAdd()
            else:
                self.matrix_G = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
                                          name="matrix_G", requires_grad=False)
                self.abs = P.Abs()
                self.reduce_max = P.ReduceMax(keep_dims=False)
                self.neg = P.Neg()
                self.reduce_sum = P.ReduceSum()
            # NOTE: re-assigns self.matmul with an identical op (harmless duplication).
            self.matmul = P.MatMul(transpose_b=True)
            self.cube_matmul = P.CusMatMulCube(transpose_a=True)
            self.cast = P.Cast()
            # is_nsp_layer: layers with 2 outputs skip G collection in save_gradient;
            # "nsp" presumably refers to BERT's next-sentence-prediction head — confirm.
            self.is_nsp_layer = (out_channels == 2)
        else:
            self.is_Ascend = False
            self.matrix_G = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
                                      name="matrix_G", requires_grad=False)
            self.cube_matmul = P.MatMul(transpose_a=True)
        # Backward hook: save_gradient runs on the gradient flowing through the
        # tensor that getG is applied to in construct().
        self.getG = P.InsertGradientOf(self.save_gradient)

    def save_gradient(self, dout):
        """
        Backward hook used only by the THOR optimizer.

        Computes matrix_G = dout^T * dout / batch_size and stores it in the
        non-trainable `matrix_G` parameter; returns `dout` unchanged so the
        gradient flow is unaffected.
        """
        out = dout
        if self.is_Ascend:
            if not self.is_nsp_layer:
                shape = self.shape(dout)
                normalizer = self.cast(shape[0], mstype.float32)
                matrix_G = self.cube_matmul(dout, dout)
                matrix_G = self.mul(matrix_G, 1.0 / normalizer)
                if self.out_channels == 1001:
                    # Pad 1001 -> 1024 to match the pre-allocated matrix_G parameter.
                    matrix_G = P.Pad(((0, 23), (0, 23)))(matrix_G)
                self.matrix_G = matrix_G
        else:
            dout_shape = self.shape(dout)
            normalizer = dout_shape[0]
            matrix_G = self.cube_matmul(dout, dout)
            matrix_G = self.mul(matrix_G, 1.0 / normalizer)
            self.matrix_G = matrix_G
        return out

    def construct(self, x):
        if self.thor:
            if self.is_Ascend:
                # matrix_A = x^T x / batch_size (CusMatMulCube has transpose_a=True).
                inputs = self.cube_matmul(x, x)
                shape = self.shape(x)
                normalizer = self.cast(shape[0], mstype.float32)
                matrix_A = self.mul(inputs, 1.0 / normalizer)
                self.matrix_A = matrix_A
            else:
                inputs = self.cube_matmul(x, x)
                inputs_shape = self.shape(inputs)
                # NOTE(review): here the divisor is the first dim of x^T x
                # (= in_channels), not the batch size as in the Ascend branch —
                # matches upstream GPU code, but worth verifying.
                normalizer = inputs_shape[0]
                matrix_A = self.mul(inputs, 1.0 / normalizer)
                self.matrix_A = matrix_A
            x = self.matmul(x, self.weight)
            # Register the backward hook so save_gradient sees d(out).
            x = self.getG(x)
        else:
            x = self.matmul(x, self.weight)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        if self.activation_flag:
            x = self.activation(x)
        return x

    def extend_repr(self):
        s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
        if self.has_bias:
            s += ', has_bias={}'.format(self.has_bias)
        # if self.activation_flag:
        # s += ', activation={}'.format(self.activation)
        return s
  171. class _Conv(Cell):
  172. """
  173. Applies a N-D convolution over an input signal composed of several input planes.
  174. """
  175. def __init__(self,
  176. in_channels,
  177. out_channels,
  178. kernel_size,
  179. stride,
  180. pad_mode,
  181. padding,
  182. dilation,
  183. group,
  184. has_bias,
  185. weight_init,
  186. bias_init,
  187. transposed=False):
  188. super(_Conv, self).__init__()
  189. self.in_channels = Validator.check_positive_int(in_channels)
  190. self.out_channels = Validator.check_positive_int(out_channels)
  191. self.kernel_size = kernel_size
  192. self.stride = stride
  193. self.pad_mode = pad_mode
  194. # self.weight_init = weight_init
  195. self.bias_init = bias_init
  196. if isinstance(padding, int):
  197. Validator.check_non_negative_int(padding, 'padding', self.cls_name)
  198. self.padding = padding
  199. elif isinstance(padding, tuple):
  200. for pad in padding:
  201. Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
  202. self.padding = padding
  203. else:
  204. raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
  205. self.dilation = dilation
  206. self.group = Validator.check_positive_int(group)
  207. self.has_bias = has_bias
  208. if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
  209. isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \
  210. kernel_size[0] < 1 or kernel_size[1] < 1:
  211. raise ValueError("Attr 'kernel_size' of 'Conv2D' Op passed "
  212. + str(self.kernel_size) + ", should be a int or tuple and equal to or greater than 1.")
  213. if (not isinstance(stride[0], int)) or (not isinstance(stride[1], int)) or \
  214. isinstance(stride[0], bool) or isinstance(stride[1], bool) or stride[0] < 1 or stride[1] < 1:
  215. raise ValueError("Attr 'stride' of 'Conv2D' Op passed "
  216. + str(self.stride) + ", should be a int or tuple and equal to or greater than 1.")
  217. if (not isinstance(dilation[0], int)) or (not isinstance(dilation[1], int)) or \
  218. isinstance(dilation[0], bool) or isinstance(dilation[1], bool) or dilation[0] < 1 or dilation[1] < 1:
  219. raise ValueError("Attr 'dilation' of 'Conv2D' Op passed "
  220. + str(self.dilation) + ", should be a int or tuple and equal to or greater than 1.")
  221. if in_channels % group != 0:
  222. raise ValueError("Attr 'in_channels' of 'Conv2D' Op must be divisible by "
  223. "attr 'group' of 'Conv2D' Op.")
  224. if out_channels % group != 0:
  225. raise ValueError("Attr 'out_channels' of 'Conv2D' Op must be divisible by "
  226. "attr 'group' of 'Conv2D' Op.")
  227. if transposed:
  228. shape = [in_channels, out_channels // group, *kernel_size]
  229. else:
  230. shape = [out_channels, in_channels // group, *kernel_size]
  231. self.weight = Parameter(initializer(weight_init, shape), name='weight')
  232. if Validator.check_bool(has_bias):
  233. self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias')
  234. else:
  235. if self.bias_init != 'zeros':
  236. logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
  237. self.bias = None
  238. def construct(self, *inputs):
  239. """Must be overridden by all subclasses."""
  240. raise NotImplementedError
class Conv2d_Thor(_Conv):
    r"""
    2D convolution layer with THOR statistic collection.

    Applies a 2D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, H_{in}, W_{in})`,
    where :math:`N` is batch size, :math:`C_{in}` is channel number, and :math:`H_{in}, W_{in})` are height and width.
    For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:

    .. math::
        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,

    where :math:`ccor` is the cross-correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th
    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
    of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
    :math:`\text{ks_w}` are the height and width of the convolution kernel. The full kernel has shape
    :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
    to split the input in the channel dimension.

    If the 'pad_mode' is set to be "valid", the output height and width will be
    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.

    The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
    <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.

    While ``self.thor`` is True, the layer additionally accumulates the THOR
    covariance statistics ``matrix_A_cov`` (from the im2col-unfolded input) and
    ``matrix_G_cov`` (from the backward gradient, via ``P.InsertGradientOf``).

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
            and width of the 2D convolution window. Single int means the value is for both the height and the width of
            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
            width of the kernel.
        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the height and width of movement are both strides, or a tuple of two int numbers that
            represent height and width of movement respectively. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are
            "same", "valid", "pad". Default: "same".

            - same: Adopts the way of completion. The height and width of the output will be the same as
              the input. The total number of padding will be calculated in horizontal and vertical
              directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the
              last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
              must be 0.
            - valid: Adopts the way of discarding. The possible largest height and width of output will be returned
              without padding. Extra pixels will be discarded. If this mode is set, `padding`
              must be 0.
            - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
              Tensor borders. `padding` must be greater than or equal to 0.
        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input. If `padding` is one integer,
            the paddings of top, bottom, left and right are the same, equal to padding. If `padding` is a tuple
            with four integers, the paddings of top, bottom, left and right will be equal to padding[0],
            padding[1], padding[2], and padding[3] accordingly. Default: 0.
        dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the dilation rate
            to use for dilated convolution. If set to be :math:`k > 1`, there will
            be :math:`k - 1` pixels skipped for each sampling location. Its value must
            be greater or equal to 1 and bounded by the height and width of the
            input. Default: 1.
        group (int): Splits filter into groups, `in_ channels` and `out_channels` must be
            divisible by the number of groups. If the group is equal to `in_channels` and `out_channels`,
            this 2D convolution layer also can be called 2D depthwise convolution layer. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
        >>> net(input).shape
        (1, 240, 1024, 640)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros'):
        kernel_size = twice(kernel_size)
        stride = twice(stride)
        # Keep the scalar dilation around; _init_depthwise_conv2d restores it
        # for the DepthwiseConv2dNative op.
        self._dilation = dilation
        dilation = twice(dilation)
        super(Conv2d_Thor, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            has_bias,
            weight_init,
            bias_init)
        self.conv2d = P.Conv2D(out_channel=self.out_channels,
                               kernel_size=self.kernel_size,
                               mode=1,
                               pad_mode=self.pad_mode,
                               pad=self.padding,
                               stride=self.stride,
                               dilation=self.dilation,
                               group=self.group)
        self._init_depthwise_conv2d(weight_init)
        self.bias_add = P.BiasAdd()
        # When True, construct() also refreshes matrix_A_cov and hooks G capture.
        self.thor = True
        self.hw = kernel_size[0] * kernel_size[1]
        # A covers the unfolded (im2col) input: C_in * kh * kw; G covers C_out.
        self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
        self.matrix_G_dim = self.out_channels
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.cast = P.Cast()
        # Normalizers (batch sizes used for A/G) stored for the optimizer.
        self.A_normalizer = Parameter(initializer(0, [1], mstype.float32), name="A_normalizer", requires_grad=False)
        self.G_normalizer = Parameter(initializer(0, [1], mstype.float32), name="G_normalizer", requires_grad=False)
        self.is_Ascend = True
        if context.get_context("device_target") == "Ascend":
            ksizes = (1, kernel_size[0], kernel_size[1], 1)
            strides = (1, stride[0], stride[1], 1)
            self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
            self.cube_matmul = P.CusMatMulCube(transpose_a=True)
            self.transpose02314 = P.CusTranspose02314()
            # Round A/G dims up to a multiple of the 128 diag block so the
            # pre-allocated covariance parameters fit the padded matrices.
            dampingA_dim = self.matrix_A_dim
            self.diag_block_dim = 128
            if (self.matrix_A_dim % self.diag_block_dim) != 0 and self.matrix_A_dim > self.diag_block_dim:
                dampingA_dim = (self.matrix_A_dim // self.diag_block_dim + 1) * self.diag_block_dim
            dampingG_dim = self.matrix_G_dim
            if (self.matrix_G_dim % self.diag_block_dim) != 0 and self.matrix_G_dim > self.diag_block_dim:
                dampingG_dim = (self.matrix_G_dim // self.diag_block_dim + 1) * self.diag_block_dim
            self.matrix_A_cov = Parameter(Tensor(np.zeros([dampingA_dim, dampingA_dim]).astype(np.float32)),
                                          name='matrix_A', requires_grad=False)
            self.matrix_G_cov = Parameter(Tensor(np.zeros([dampingG_dim, dampingG_dim]).astype(np.float32)),
                                          name='matrix_G', requires_grad=False)
            # C0 = 16: channel packing unit; when in_channels is not a multiple
            # of 16 the im2col result is padded and must be sliced back.
            self.channels_slice_flag = False
            self.C0 = 16
            if self.in_channels % self.C0 != 0:
                self.channels_slice_flag = True
            self.padA_flag = False
            if (self.matrix_A_dim // self.diag_block_dim) * self.diag_block_dim != self.matrix_A_dim \
                    and self.matrix_A_dim > self.diag_block_dim:
                self.padA_flag = True
                pad_dim = self.diag_block_dim - self.matrix_A_dim % self.diag_block_dim
                self.padA = P.Pad(((0, pad_dim), (0, pad_dim)))
            self.slice = P.Slice()
        else:
            self.is_Ascend = False
            self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
            self.matmul = P.MatMul(transpose_b=True)
            self.reduce_mean = P.ReduceMean(keep_dims=False)
            self.matrix_A_cov = Parameter(Tensor(np.zeros([self.matrix_A_dim, self.matrix_A_dim]).astype(np.float32)),
                                          name='matrix_A', requires_grad=False)
            self.matrix_G_cov = Parameter(Tensor(np.zeros([self.matrix_G_dim, self.matrix_G_dim]).astype(np.float32)),
                                          name='matrix_G', requires_grad=False)
        # Backward hook: save_gradient runs on the gradient of the conv output.
        self.getG = P.InsertGradientOf(self.save_gradient)

    def _init_depthwise_conv2d(self, weight_init):
        """Initialize depthwise conv2d op"""
        # On Ascend a grouped conv with group == in_channels == out_channels is
        # lowered to DepthwiseConv2dNative, which needs a [1, C_in, kh, kw] weight.
        if context.get_context("device_target") == "Ascend" and self.group > 1:
            self.dilation = self._dilation
            Validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
            Validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
            self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                  kernel_size=self.kernel_size,
                                                  pad_mode=self.pad_mode,
                                                  pad=self.padding,
                                                  stride=self.stride,
                                                  dilation=self.dilation)
            weight_shape = [1, self.in_channels, *self.kernel_size]
            self.weight_init = weight_init
            if isinstance(weight_init, Tensor):
                # Swap the first two axes to the depthwise [1, C_in, ...] layout.
                self.weight_init = Tensor(weight_init.asnumpy().swapaxes(0, 1), weight_init.dtype)
            if isinstance(weight_init, Initializer):
                self.weight_init.shape = weight_shape
            self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')

    def save_gradient(self, dout):
        """
        Backward hook used only by the THOR optimizer: computes
        matrix_G_cov = dout^T * dout / normalizer and stores it together with
        the normalizer; returns `dout` unchanged.
        """
        out = dout
        if self.is_Ascend:
            # Reorder NC1HWC0 gradient into a 2-D [N*H*W, C_out] matrix.
            dout = self.transpose02314(dout)
            dout_shape = self.shape(dout)
            normalizer = dout_shape[0]
            matrix_G = self.cube_matmul(dout, dout)
            normalizer = self.cast(normalizer, mstype.float32)
            matrix_G = self.mul(matrix_G, 1.0 / normalizer)
            self.G_normalizer = normalizer
            self.matrix_G_cov = matrix_G
        else:
            # Average over the batch dim first, then flatten spatial dims.
            dout = self.reduce_mean(dout, 0)
            dout_shape = self.shape(dout)
            dout = self.reshape(dout, (dout_shape[0], -1))
            dout_shape = self.shape(dout)
            normalizer = dout_shape[1]
            dout = self.cast(dout, mstype.float32)
            matrix_G = self.matmul(dout, dout)
            matrix_G = self.mul(matrix_G, 1.0 / normalizer)
            self.G_normalizer = normalizer
            self.matrix_G_cov = matrix_G
        return out

    def construct(self, x):
        if self.thor:
            matrix_A = self.img2col(x)
            matrix_A_shape = self.shape(matrix_A)
            if self.is_Ascend:
                normalizer = matrix_A_shape[0]
                matrix_A = self.cube_matmul(matrix_A, matrix_A)
                if self.channels_slice_flag:
                    # Undo the C0=16 channel padding introduced by CusImg2Col.
                    matrix_A = self.reshape(matrix_A, (self.hw, self.C0, self.hw, self.C0))
                    matrix_A = self.slice(matrix_A, (0, 0, 0, 0),
                                          (self.hw, self.in_channels, self.hw, self.in_channels))
                    matrix_A = self.reshape(matrix_A, (self.matrix_A_dim, self.matrix_A_dim))
                normalizer = self.cast(normalizer, mstype.float32)
                matrix_A = self.mul(matrix_A, 1.0 / normalizer)
                if self.padA_flag:
                    # Pad up to the 128-block size of the pre-allocated parameter.
                    matrix_A = self.padA(matrix_A)
                self.A_normalizer = normalizer
                self.matrix_A_cov = matrix_A
            else:
                matrix_A = self.reshape(matrix_A, (matrix_A_shape[0] * matrix_A_shape[1] * matrix_A_shape[2],
                                                   matrix_A_shape[3], -1))
                matrix_A = self.reduce_mean(matrix_A, 1)
                matrix_A_shape = self.shape(matrix_A)
                normalizer = matrix_A_shape[1]
                matrix_A = self.cast(matrix_A, mstype.float32)
                matrix_A = self.matmul(matrix_A, matrix_A)
                matrix_A = self.mul(matrix_A, 1.0 / normalizer)
                self.A_normalizer = normalizer
                self.matrix_A_cov = matrix_A
            output = self.conv2d(x, self.weight)
            # Register the backward hook so save_gradient sees d(output).
            output = self.getG(output)
        else:
            output = self.conv2d(x, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        return output

    def extend_repr(self):
        # NOTE(review): relies on self.weight_init being set; in this file _Conv
        # only sets it on the Ascend depthwise path (_init_depthwise_conv2d) —
        # verify it is always present before calling repr().
        s = 'input_channels={}, output_channels={}, kernel_size={},' \
            'stride={}, pad_mode={}, padding={}, dilation={}, ' \
            'group={}, has_bias={},' \
            'weight_init={}, bias_init={}'.format(
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.pad_mode,
                self.padding,
                self.dilation,
                self.group,
                self.has_bias,
                self.weight_init,
                self.bias_init)
        return s
class Embedding_Thor(Cell):
    r"""
    A simple lookup table that stores embeddings of a fixed dictionary and size,
    with THOR statistic collection.

    This module is often used to store word embeddings and retrieve them using
    indices. The input to the module is a list of indices, and the output is
    the corresponding word embeddings.

    While ``self.thor`` is True (and one-hot mode is off), the layer also
    accumulates ``matrix_A`` (per-token occurrence counts over the batch) and
    ``matrix_G`` (dout^T dout / batch, captured on the backward pass).

    Note:
        When 'use_one_hot' is set to True, the type of the input must be mindspore.int32.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
        embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string
            is specified. Default: 'normal'.
        dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
        padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
            will be initialized to zero. Default: None. The feature is inactivated.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
          the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
          be zero.

    Outputs:
        Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.

    Examples:
        >>> net = nn.Embedding(20000, 768, True)
        >>> input_data = Tensor(np.ones([8, 128]), mindspore.int32)
        >>>
        >>> # Maps the input word IDs to word embedding.
        >>> output = net(input_data)
        >>> output.shape
        (8, 128, 768)
    """

    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
                 dtype=mstype.float32, padding_idx=None):
        super(Embedding_Thor, self).__init__()
        self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
        self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
        Validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
        Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.use_one_hot = use_one_hot
        self.dtype = dtype
        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
        self.padding_idx = padding_idx
        if padding_idx is not None:
            self.padding_idx = Validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
                                                         "padding_idx", self.cls_name)
            # Materialize the initializer so the padding row can be zeroed in place.
            self.init_tensor = self.init_tensor.to_tensor().asnumpy()
            self.init_tensor[self.padding_idx] = 0
        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
        self.expand = P.ExpandDims()
        self.reshape_flat = P.Reshape()
        self.shp_flat = (-1,)
        self.gather = P.GatherV2()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, self.dtype)
        self.off_value = Tensor(0.0, self.dtype)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.get_shp = P.Shape()
        # When True, construct() also refreshes matrix_A and hooks matrix_G capture.
        self.thor = True
        # For an embedding, A is a vector of per-token counts (not a full matrix).
        self.matrix_A = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                  name='matrix_A', requires_grad=False)
        self.matrix_G = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                  name="matrix_G", requires_grad=False)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        # Backward hook: save_gradient runs on the gradient of the gathered output.
        self.getG = P.InsertGradientOf(self.save_gradient)
        self.cast = P.Cast()
        if context.get_context("device_target") == "Ascend":
            self.cube_matmul = P.CusMatMulCube(transpose_a=True)
        else:
            self.cube_matmul = P.MatMul(transpose_a=True)
        self.mul = P.Mul()

    def save_gradient(self, dout):
        """
        Backward hook used only by the THOR optimizer: computes
        matrix_G = dout^T * dout / batch_size, stores it, and returns `dout`
        unchanged so the gradient flow is unaffected.
        """
        out = dout
        shape = self.get_shp(dout)
        normalizer = self.cast(shape[0], mstype.float32)
        matrix_G = self.cube_matmul(dout, dout)
        matrix_G = self.mul(matrix_G, 1.0 / normalizer)
        self.matrix_G = matrix_G
        return out

    def construct(self, ids):
        extended_ids = self.expand(ids, -1)
        out_shape = self.get_shp(ids) + (self.embedding_size,)
        flat_ids = self.reshape_flat(extended_ids, self.shp_flat)
        if self.use_one_hot:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
        else:
            if self.thor:
                # matrix_A = column sums of the one-hot ids: how often each
                # vocabulary entry appears in this batch.
                one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
                matrix_A = self.reduce_sum(one_hot_ids, 0)
                self.matrix_A = matrix_A
                output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
                # Register the backward hook so save_gradient sees the gradient.
                output_for_reshape = self.getG(output_for_reshape)
            else:
                output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, out_shape)
        return output

    def extend_repr(self):
        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
        return s