thor_layer.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Layers for second-order optimization."""
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.log as logger
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer, Initializer
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore._checkparam import Validator, Rel, twice
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
from mindspore.parallel._ps_context import _set_cache_enable, _set_rank_id, _insert_hash_table_size
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
from mindspore.context import ParallelMode
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .basic import ClipByNorm

__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor', 'EmbeddingLookupThor']
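
# THOR keeps two Kronecker-factor statistics per layer: A, the second moment of
# the layer's inputs, and G, the second moment of the gradients of the loss with
# respect to the layer's outputs (see the AAAI'21 paper linked in the docstrings
# below). Each layer below records A during its forward pass and captures G on
# the backward pass through a P.InsertGradientOf hook.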


class DenseThor(Cell):
    r"""
    The dense connected layer, which additionally saves the information needed for THOR.

    Applies a dense connected layer to the input and saves the matrices A and G needed for THOR;
    for details see the paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf

    This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

    where :math:`\text{activation}` is the activation function, :math:`\text{kernel}` is a weight matrix with the
    same data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
    with the same data type as the inputs created by the layer (only if `has_bias` is True).

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `x`. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            the same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str): Activation function applied to the output of the fully connected layer, e.g. 'relu'.
            Default: None.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Raises:
        ValueError: If the shape of `weight_init` or `bias_init` is incorrect.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([[1, 2, 3], [3, 4, 5]]), mindspore.float32)
        >>> net = nn.DenseThor(3, 4, weight_init="ones")
        >>> output = net(x)
        >>> print(output)
        [[ 6.  6.  6.  6.]
         [12. 12. 12. 12.]]
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None):
        """Initialize DenseThor."""
        super(DenseThor, self).__init__()
        self.thor = True
        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
        self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name)
        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError(f"For '{self.cls_name}', weight init shape error. The dim of 'weight_init' should "
                                 f"be equal to 2, and the first dim should be equal to 'out_channels', and the "
                                 f"second dim should be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
                                 f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
        self.bias = None
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The dim of 'bias_init' should "
                                     f"be equal to 1, and the first dim should be equal to 'out_channels'. But got "
                                     f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
            self.bias_add = P.BiasAdd()
        self.matmul = P.MatMul(transpose_b=True)
        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None
        self.matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float32)),
                                  name='matrix_a', requires_grad=False)
        self.matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
                                  name="matrix_g", requires_grad=False)
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.mul = P.Mul()
        self.is_Ascend = True
        self.split_dim = 128
        if context.get_context("device_target") == "Ascend":
            self._process_ascend_dense_thor(out_channels, in_channels)
        else:
            self.is_Ascend = False
            self.cube_matmul = P.MatMul(transpose_a=True)
        self.getG = P.InsertGradientOf(self.save_gradient)
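        # The InsertGradientOf hook above makes `save_gradient` run during
        # backpropagation: the wrapped tensor passes through unchanged in the
        # forward direction, and the incoming gradient `dout` is handed to
        # `save_gradient`, which records G before returning the gradient untouched.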

    def _process_ascend_dense_thor(self, out_channels, in_channels):
        """Process Ascend dense THOR."""
        self.matmul = P.MatMul(transpose_b=True)
        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
        self.cast = P.Cast()
        self.is_nsp_layer = (out_channels == 2)

    def save_gradient(self, dout):
        """Save the gradient statistic G; this function is only used by the THOR optimizer."""
        out = dout
        if self.is_Ascend:
            if not self.is_nsp_layer:
                shape = self.shape(dout)
                normalizer = self.cast(shape[0], mstype.float32)
                matrix_g = self.cube_matmul(dout, dout)
                matrix_g = self.mul(matrix_g, 1.0 / normalizer)
                self.matrix_g = matrix_g
        else:
            dout_shape = self.shape(dout)
            normalizer = dout_shape[0]
            matrix_g = self.cube_matmul(dout, dout)
            matrix_g = self.mul(matrix_g, 1.0 / normalizer)
            self.matrix_g = matrix_g
        return out

    def construct(self, x):
        if self.thor:
            if self.is_Ascend:
                inputs = self.cube_matmul(x, x)
                shape = self.shape(x)
                normalizer = self.cast(shape[0], mstype.float32)
                matrix_a = self.mul(inputs, 1.0 / normalizer)
                self.matrix_a = matrix_a
            else:
                inputs = self.cube_matmul(x, x)
                inputs_shape = self.shape(inputs)
                normalizer = inputs_shape[0]
                matrix_a = self.mul(inputs, 1.0 / normalizer)
                self.matrix_a = matrix_a
            x = self.matmul(x, self.weight)
            x = self.getG(x)
        else:
            x = self.matmul(x, self.weight)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        if self.activation_flag:
            x = self.activation(x)
        return x

    def extend_repr(self):
        s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
        if self.has_bias:
            s += ', has_bias={}'.format(self.has_bias)
        return s
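
# A minimal sketch of how DenseThor exposes its THOR statistics (assumes a
# working MindSpore install; mirrors the docstring example above):
#
#     net = DenseThor(3, 4, weight_init="ones")
#     x = Tensor(np.array([[1, 2, 3], [3, 4, 5]]), mstype.float32)
#     y = net(x)            # forward pass records net.matrix_a = x^T x / N
#     a = net.matrix_a      # (3, 3) input statistic A
#     # net.matrix_g is filled in by the InsertGradientOf hook once a backward
#     # pass (e.g. driven by the THOR optimizer) flows through the layer.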


class _ConvThor(Cell):
    """Applies an N-D convolution over an input signal composed of multiple input planes."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode,
                 padding, dilation, group, has_bias, weight_init, bias_init, transposed=False):
        """Initialize _ConvThor."""
        super(_ConvThor, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.bias_init = bias_init
        if isinstance(padding, tuple):
            for pad in padding:
                Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
            self.padding = padding
        elif isinstance(padding, int):
            Validator.check_non_negative_int(padding, 'padding', self.cls_name)
            self.padding = padding
        else:
            raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), but got "
                            f"{type(padding).__name__}.")
        self.dilation = dilation
        self.group = Validator.check_positive_int(group, "group", self.cls_name)
        self.has_bias = has_bias
        self.__validate_kernel_size(kernel_size)
        self.__validate_stride(stride)
        self.__validate_dilation(dilation)
        if in_channels % group != 0:
            raise ValueError(f"For '{self.cls_name}', the 'in_channels' must be divisible by 'group', but got "
                             f"'in_channels': {in_channels} and 'group': {group}.")
        if out_channels % group != 0:
            raise ValueError(f"For '{self.cls_name}', the 'out_channels' must be divisible by 'group', but got "
                             f"'out_channels': {out_channels} and 'group': {group}.")
        if not transposed:
            shape = [out_channels, in_channels // group, *kernel_size]
        else:
            shape = [in_channels, out_channels // group, *kernel_size]
        self.weight = Parameter(initializer(weight_init, shape), name='weight')
        if Validator.check_bool(has_bias, "has_bias", self.cls_name):
            self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias')
        else:
            if self.bias_init != 'zeros':
                logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
            self.bias = None

    def __validate_kernel_size(self, kernel_size):
        """Validate kernel size."""
        if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
                isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \
                kernel_size[0] < 1 or kernel_size[1] < 1:
            raise ValueError(f"For '{self.cls_name}', all elements in 'kernel_size' should be int and "
                             f"equal to or greater than 1, but got 'kernel_size': {kernel_size}.")

    def __validate_stride(self, stride):
        """Validate stride."""
        if (not isinstance(stride[0], int)) or (not isinstance(stride[1], int)) or \
                isinstance(stride[0], bool) or isinstance(stride[1], bool) or stride[0] < 1 or stride[1] < 1:
            raise ValueError(f"For '{self.cls_name}', all elements in 'stride' should be int and "
                             f"equal to or greater than 1, but got 'stride': {stride}.")

    def __validate_dilation(self, dilation):
        """Validate dilation."""
        if (not isinstance(dilation[0], int)) or (not isinstance(dilation[1], int)) or \
                isinstance(dilation[0], bool) or isinstance(dilation[1], bool) or dilation[0] < 1 or dilation[1] < 1:
            raise ValueError(f"For '{self.cls_name}', all elements in 'dilation' should be int and "
                             f"equal to or greater than 1, but got 'dilation': {dilation}.")


class Conv2dThor(_ConvThor):
    r"""
    2D convolution layer, which additionally saves the information needed for THOR.

    Applies a 2D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, H_{in}, W_{in})`,
    where :math:`N` is batch size, :math:`C_{in}` is channel number, and :math:`H_{in}, W_{in}` are height and width.
    It also saves the matrices A and G needed for THOR in the 2D convolution layer.
    For details see the paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf

    For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:

    .. math::
        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,

    where :math:`ccor` is the cross-correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th
    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
    of the kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
    :math:`\text{ks_w}` are the height and width of the convolution kernel. The full kernel has shape
    :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where `group` is the number of groups
    to split the input `x` into in the channel dimension.

    If `pad_mode` is set to "valid", the output height and width will be
    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.

    Note:
        For Ascend, the type of inputs should be a subclass of Tensor[Float16] or Tensor[Int8].
        For GPU, the type of inputs should be a subclass of Tensor[Float32].

    Args:
        in_channels (int): The number of input channels :math:`C_{in}`.
        out_channels (int): The number of output channels :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
            and width of the 2D convolution window. A single int means that the value is both the height and
            the width of the kernel. A tuple of 2 integers means the height and the width of the kernel respectively.
        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents the height and
            width of movement, or a tuple of two int numbers that represent height and width of movement, respectively.
            Default: 1.
        pad_mode (str): Specifies the padding mode. The optional values are
            "same", "valid", "pad". Default: "same".

            - same: Adopts the way of completion. The shape of the output will be the same as
              the `x`. The total number of padding will be calculated in the horizontal and vertical
              directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the
              last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
              must be 0.

            - valid: Adopts the way of discarding. The possibly largest height and width of the output will be
              returned without padding. Extra pixels will be discarded. If this mode is set, `padding` must be 0.

            - pad: Implicit paddings on both sides of the input `x`. The number of `padding` will be padded to the
              input Tensor borders. `padding` must be greater than or equal to 0.
        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input `x`. If `padding` is an integer,
            the paddings of top, bottom, left and right are the same, equal to padding. If `padding` is a tuple
            with four integers, the paddings of top, bottom, left and right will be equal to padding[0],
            padding[1], padding[2], and padding[3] accordingly. Default: 0.
        dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the dilation rate
            to use for dilated convolution. If set to :math:`k > 1`, there will
            be :math:`k - 1` pixels skipped for each sampling location. Its value must
            be greater than or equal to 1 and bounded by the height and width of the input `x`.
            Default: 1.
        group (int): Splits the filter into groups; `in_channels` and `out_channels` must be
            divisible by the number of groups. If the group is equal to `in_channels` and `out_channels`,
            this 2D convolution layer can also be called a 2D depthwise convolution layer. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializes the convolution kernel.
            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as the constant 'One' and 'Zero' distributions are possible. The aliases 'xavier_uniform', 'he_uniform',
            'ones' and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializes the bias vector. Possible
            Initializer and string values are the same as for 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = nn.Conv2dThor(120, 240, 4, has_bias=False, weight_init='normal')
        >>> # for Ascend
        >>> x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float16)
        >>> print(net(x).shape)
        (1, 240, 1024, 640)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 pad_mode='same', padding=0, dilation=1, group=1, has_bias=False,
                 weight_init='normal', bias_init='zeros'):
        """Initialize Conv2dThor."""
        kernel_size = twice(kernel_size)
        stride = twice(stride)
        self._dilation = dilation
        dilation = twice(dilation)
        super(Conv2dThor, self).__init__(in_channels, out_channels, kernel_size,
                                         stride, pad_mode, padding, dilation, group, has_bias,
                                         weight_init, bias_init)
        self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size,
                               mode=1, pad_mode=self.pad_mode, pad=self.padding,
                               stride=self.stride, dilation=self.dilation, group=self.group)
        self._init_depthwise_conv2d(weight_init)
        self.bias_add = P.BiasAdd()
        self.thor = True
        self.hw = kernel_size[0] * kernel_size[1]
        self.matrix_a_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
        self.matrix_g_dim = self.out_channels
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.cast = P.Cast()
        self.a_normalizer = Parameter(initializer(1, [1], mstype.float32), name="a_normalizer", requires_grad=False)
        self.g_normalizer = Parameter(initializer(1, [1], mstype.float32), name="g_normalizer", requires_grad=False)
        self.is_Ascend = True
        if context.get_context("device_target") == "Ascend":
            self._process_ascend_conv2d_thor(kernel_size, stride)
        else:
            self.is_Ascend = False
            self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
            self.matmul = P.MatMul(transpose_b=True)
            self.reduce_mean = P.ReduceMean(keep_dims=False)
            self.matrix_a_cov = Parameter(Tensor(np.zeros([self.matrix_a_dim, self.matrix_a_dim]).astype(np.float32)),
                                          name='matrix_a', requires_grad=False)
            self.matrix_g_cov = Parameter(Tensor(np.zeros([self.matrix_g_dim, self.matrix_g_dim]).astype(np.float32)),
                                          name='matrix_g', requires_grad=False)
        self.getG = P.InsertGradientOf(self.save_gradient)
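        # For a conv layer, A is the covariance of the im2col-unfolded input
        # patches, of shape (C_in*ks_h*ks_w, C_in*ks_h*ks_w), and G is the
        # covariance of the output gradients, of shape (C_out, C_out); the
        # InsertGradientOf hook above fills matrix_g_cov during the backward pass.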

    def _process_ascend_conv2d_thor(self, kernel_size, stride):
        """Process Ascend conv2d THOR."""
        ksizes = (1, kernel_size[0], kernel_size[1], 1)
        strides = (1, stride[0], stride[1], 1)
        ksizes_tbe = (kernel_size[0], kernel_size[1])
        self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()
        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
        self.diag_block_dim = 128
        self.matrix_a_cov = Parameter(Tensor(np.eye(self.matrix_a_dim).astype(np.float32)),
                                      name='matrix_a', requires_grad=False)
        self.matrix_g_cov = Parameter(Tensor(np.eye(self.matrix_g_dim).astype(np.float32)),
                                      name='matrix_g', requires_grad=False)
        self.slice = P.Slice()
        self.im2col = P.NewIm2Col(ksizes=ksizes_tbe, strides=stride[0], padding_mode="SAME")

    def _init_depthwise_conv2d(self, weight_init):
        """Initialize depthwise conv2d op."""
        if context.get_context("device_target") == "Ascend" and self.group > 1:
            self.dilation = self._dilation
            Validator.check_int('group', self.group, self.in_channels, Rel.EQ, self.cls_name)
            Validator.check_int('group', self.group, self.out_channels, Rel.EQ, self.cls_name)
            self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                  kernel_size=self.kernel_size,
                                                  pad_mode=self.pad_mode,
                                                  pad=self.padding,
                                                  stride=self.stride,
                                                  dilation=self.dilation)
            weight_shape = [1, self.in_channels, *self.kernel_size]
            self.weight_init = weight_init
            if isinstance(weight_init, Tensor):
                self.weight_init = Tensor(weight_init.asnumpy().swapaxes(0, 1), weight_init.dtype)
            if isinstance(weight_init, Initializer):
                self.weight_init.shape = weight_shape
            self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')

    def save_gradient(self, dout):
        """Save the gradient statistic G; this function is only used by the THOR optimizer."""
        out = dout
        if self.is_Ascend:
            dout_shape = self.shape(dout)
            dout = self.transpose(dout, (0, 2, 3, 1))
            dout = self.reshape(dout, (-1, dout_shape[1]))
            dout_shape = self.shape(dout)
            normalizer = dout_shape[0]
            matrix_g = self.cube_matmul(dout, dout)
            normalizer = self.cast(normalizer, mstype.float32)
            matrix_g = self.mul(matrix_g, 1.0 / normalizer)
            self.g_normalizer = normalizer
            self.matrix_g_cov = matrix_g
        else:
            dout = self.reduce_mean(dout, 0)
            dout_shape = self.shape(dout)
            dout = self.reshape(dout, (dout_shape[0], -1))
            dout_shape = self.shape(dout)
            normalizer = dout_shape[1]
            dout = self.cast(dout, mstype.float32)
            matrix_g = self.matmul(dout, dout)
            matrix_g = self.mul(matrix_g, 1.0 / normalizer)
            self.g_normalizer = normalizer
            self.matrix_g_cov = matrix_g
        return out

    def construct(self, x):
        if self.thor:
            if self.is_Ascend:
                matrix_a = self.im2col(x)
                matrix_a_shape = self.shape(matrix_a)
                y = matrix_a_shape[3]
                matrix_a = self.reshape(matrix_a, (-1, y))
                matrix_a_shape = self.shape(matrix_a)
                normalizer = matrix_a_shape[0]
                matrix_a = self.cube_matmul(matrix_a, matrix_a)
                normalizer = self.cast(normalizer, mstype.float32)
                matrix_a = self.mul(matrix_a, 1.0 / normalizer)
                self.a_normalizer = normalizer
                self.matrix_a_cov = matrix_a
                weight = self.cast(self.weight, mstype.float16)
                output = self.conv2d(x, weight)
                output = self.getG(output)
            else:
                matrix_a = self.img2col(x)
                matrix_a_shape = self.shape(matrix_a)
                matrix_a = self.reshape(matrix_a, (matrix_a_shape[0] * matrix_a_shape[1] * matrix_a_shape[2],
                                                   matrix_a_shape[3], -1))
                matrix_a = self.reduce_mean(matrix_a, 1)
                matrix_a_shape = self.shape(matrix_a)
                normalizer = matrix_a_shape[1]
                matrix_a = self.cast(matrix_a, mstype.float32)
                matrix_a = self.matmul(matrix_a, matrix_a)
                matrix_a = self.mul(matrix_a, 1.0 / normalizer)
                self.a_normalizer = normalizer
                self.matrix_a_cov = matrix_a
                output = self.conv2d(x, self.weight)
                output = self.getG(output)
        else:
            if self.is_Ascend:
                weight = self.cast(self.weight, mstype.float16)
                output = self.conv2d(x, weight)
            else:
                output = self.conv2d(x, self.weight)
        if self.has_bias:
            if self.is_Ascend:
                bias = self.cast(self.bias, mstype.float16)
                output = self.bias_add(output, bias)
            else:
                output = self.bias_add(output, self.bias)
        return output

    def extend_repr(self):
        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, has_bias={}, ' \
            'bias_init={}'.format(self.in_channels, self.out_channels, self.kernel_size,
                                  self.stride, self.pad_mode, self.padding, self.dilation,
                                  self.group, self.has_bias, self.bias_init)
        return s
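
# A minimal sketch (assumes a MindSpore install with an Ascend or GPU backend;
# input shape chosen here for illustration): the forward pass both convolves and
# records the input statistic A, so matrix_a_cov is populated immediately.
#
#     net = Conv2dThor(120, 240, 4, has_bias=False, weight_init='normal')
#     x = Tensor(np.ones([1, 120, 32, 32]), mstype.float16)   # Ascend expects float16
#     y = net(x)              # records net.matrix_a_cov and net.a_normalizer
#     a = net.matrix_a_cov    # (120*4*4, 120*4*4) unfolded-input statistic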


class EmbeddingThor(Cell):
    r"""
    A simple lookup table that stores embeddings of a fixed dictionary and size,
    and additionally saves the information needed for THOR.

    This module is often used to store word embeddings and retrieve them using
    indices. The input to the module is a list of indices, and the output is
    the corresponding word embeddings. It also saves the matrices A and G needed for THOR;
    for details see the paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf

    Note:
        When `use_one_hot` is set to True, the type of the input `x` must be mindspore.int32.

    Args:
        vocab_size (int): The size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        use_one_hot (bool): Specifies whether to apply one-hot encoding. Default: False.
        embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializes the embedding_table.
            Refer to class `initializer` for the values of string when a string is specified. Default: 'normal'.
        dtype (:class:`mindspore.dtype`): Data type of the input `x`. Default: mindspore.float32.
        padding_idx (int, None): If given, the output embedding vector at index `padding_idx`
            is initialized to zero. Default: None (the feature is disabled).

    Inputs:
        - **x** (Tensor) - Tensor of input shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
          the Tensor must be integers and not larger than vocab_size; otherwise, the corresponding embedding
          vector will be zero.

    Outputs:
        Tensor of output shape :math:`(\text{batch_size}, \text{x_length}, \text{embedding_size})`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = nn.EmbeddingThor(20000, 768, True)
        >>> x = Tensor(np.ones([8, 128]), mindspore.int32)
        >>>
        >>> # Maps the input word IDs to word embeddings.
        >>> output = net(x)
        >>> output.shape
        (8, 128, 768)
    """

    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
                 dtype=mstype.float32, padding_idx=None):
        """Initialize EmbeddingThor."""
        super(EmbeddingThor, self).__init__()
        self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
        self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
        Validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
        Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.use_one_hot = use_one_hot
        self.dtype = dtype
        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
        self.padding_idx = padding_idx
        if padding_idx is not None:
            self.padding_idx = Validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
                                                         "padding_idx", self.cls_name)
            self.init_tensor = self.init_tensor.to_tensor().asnumpy()
            self.init_tensor[self.padding_idx] = 0
        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
        self.expand = P.ExpandDims()
        self.reshape_flat = P.Reshape()
        self.shp_flat = (-1,)
        self.gather = P.GatherV2()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, self.dtype)
        self.off_value = Tensor(0.0, self.dtype)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.get_shp = P.Shape()
        self.thor = True
        self.matrix_a = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                  name='matrix_a', requires_grad=False)
        self.matrix_g = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                  name="matrix_g", requires_grad=False)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.getG = P.InsertGradientOf(self.save_gradient)
        self.cast = P.Cast()
        if context.get_context("device_target") == "Ascend":
            self.cube_matmul = P.CusMatMulCube(transpose_a=True)
        else:
            self.cube_matmul = P.MatMul(transpose_a=True)
        self.mul = P.Mul()
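        # For an embedding layer the input statistic A is not a dense covariance:
        # each one-hot input row co-occurs only with itself, so A reduces to a
        # diagonal and is stored as a length-`vocab_size` vector of token counts,
        # accumulated in construct() below.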

    def save_gradient(self, dout):
        """Save the gradient statistic G; this function is only used by the THOR optimizer."""
        out = dout
        shape = self.get_shp(dout)
        normalizer = self.cast(shape[0], mstype.float32)
        matrix_g = self.cube_matmul(dout, dout)
        matrix_g = self.mul(matrix_g, 1.0 / normalizer)
        self.matrix_g = matrix_g
        return out

    def construct(self, ids):
        extended_ids = self.expand(ids, -1)
        out_shape = self.get_shp(ids) + (self.embedding_size,)
        flat_ids = self.reshape_flat(extended_ids, self.shp_flat)
        if self.use_one_hot:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
        else:
            if self.thor:
                one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
                matrix_a = self.reduce_sum(one_hot_ids, 0)
                self.matrix_a = matrix_a
                output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
                output_for_reshape = self.getG(output_for_reshape)
            else:
                output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, out_shape)
        return output

    def extend_repr(self):
        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype,
            self.padding_idx)
        return s


@constexpr
def _make_axis_range(start, end):
    axis = tuple(range(start, end))
    return axis
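
# _make_axis_range is marked @constexpr so the axis tuple is computed at graph
# compile time from the static ranks; EmbeddingLookupThor.construct uses it to
# build the reduction axes for ClipByNorm when `max_norm` is set.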


class EmbeddingLookupThor(Cell):
    r"""
    Returns a slice of the input tensor based on the specified indices,
    and additionally saves the information needed for THOR.

    This module has the same function as EmbeddingLookup, but additionally saves the matrices A and G
    needed for THOR in the embedding lookup layer;
    for details see the paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf

    Args:
        vocab_size (int): The size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string is specified.
            Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must be in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be one of the
            slice modes defined by nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        manual_shapes (tuple): The accompaniment array in field slice mode.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 or None.
            Default: None.
        sparse (bool): Uses sparse mode. When 'target' is set to 'CPU', 'sparse' has to be True.
            Default: True.
        vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only for the
            'DEVICE' target. The moment parameter of the corresponding optimizer will also be set to the cache
            size. Note that the cache consumes 'DEVICE' memory, so a reasonable value is suggested to avoid
            running out of memory.

    Inputs:
        - **input_indices** (Tensor) - The shape of the tensor is :math:`(y_1, y_2, ..., y_S)`.

    Outputs:
        Tensor, the shape of the tensor is :math:`(z_1, z_2, ..., z_N)`.

    Raises:
        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
        ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice',
            'table_row_slice' or 'table_column_slice'.
        ValueError: If `sparse` is False and `target` is 'CPU'.
        ValueError: If `slice_mode` is 'field_slice' and `manual_shapes` is None.
        TypeError: If `vocab_size` or `embedding_size` or `vocab_cache_size` is not an int.
        TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple.
        ValueError: If `vocab_size` or `embedding_size` is less than 1.
        ValueError: If `vocab_cache_size` is less than 0.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
        >>> result = nn.EmbeddingLookupThor(4, 2)(input_indices)
        >>> print(result.shape)
        (2, 2, 2)
    """
    BATCH_SLICE = "batch_slice"
    FIELD_SLICE = "field_slice"
    TABLE_ROW_SLICE = "table_row_slice"
    TABLE_COLUMN_SLICE = "table_column_slice"

    def __init__(self, vocab_size, embedding_size, param_init='normal',
                 target='CPU', slice_mode='batch_slice', manual_shapes=None,
                 max_norm=None, sparse=True, vocab_cache_size=0):
        """Initialize EmbeddingLookupThor."""
        super(EmbeddingLookupThor, self).__init__()
        Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
        self.vocab_size = Validator.check_positive_int(vocab_size, 'vocab_size', self.cls_name)
        self.vocab_cache_size = Validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size', self.cls_name)
        self.target = target
        self.sparse = sparse
        self.cache_enable = self.vocab_cache_size > 0
        self.forward_unique = False
        self.dtype = mstype.float16
        if target not in ('CPU', 'DEVICE'):
            raise ValueError(f"For '{self.cls_name}', the 'target' should be one of values in ('CPU', 'DEVICE'), "
                             f"but got {target}.")
        if not sparse and target == 'CPU':
            raise ValueError(f"For '{self.cls_name}', embedding_lookup must be sparse when 'target' is CPU, but got "
                             f"'sparse': {sparse}, 'target': {target}.")
        if sparse:
            self.gatherv2 = P.SparseGatherV2()
        else:
            self.gatherv2 = P.Gather()
        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
        enable_ps = _get_ps_context("enable_ps")
        if enable_ps:
            self._process_vocab_cache(slice_mode)
        self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
                                                     mstype.float16), name='embedding_table')
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        self.gather_revert = P.Gather()
        self.reshape_first = P.Reshape()
        self.reshape = P.Reshape()
        self.unique = P.Unique()
        self.shape = P.Shape()
        if is_auto_parallel:
            self.unique = P.Unique().shard(((1,),))
        if self.cache_enable and enable_ps:
            self._set_vocab_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
            if is_auto_parallel:
                self.unique.add_prim_attr('cache_enable', True)
        indices_shape_size = 2
        if slice_mode == "field_slice" and is_auto_parallel:
            if not manual_shapes:
                raise ValueError(f"For '{self.cls_name}', the 'manual_shapes' should not be None "
                                 f"when 'slice_mode' is 'field_slice'.")
            if not isinstance(manual_shapes, tuple):
                raise TypeError(f"For '{self.cls_name}', the type of 'manual_shapes' must be tuple(int), but got "
                                f"type {type(manual_shapes).__name__}.")
            for dim in manual_shapes:
                Validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
            self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
            self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
        elif slice_mode == "table_row_slice" and is_auto_parallel:
            full_batch = _get_full_batch()
            if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
                indices_shape_size = 1
                self.gather_revert.shard(((1, 1), (get_group_size(),)))
                self.forward_unique = True
            indices_strategy = (1,) * indices_shape_size
            self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
            self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            if target == 'DEVICE':
                indices_shape_size = 1
                self.gather_revert.shard(((1, get_group_size()), (1,)))
                self.forward_unique = True
            indices_strategy = (1,) * indices_shape_size
            self.gatherv2.shard(((1, get_group_size()), indices_strategy))
            self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
        elif slice_mode == "batch_slice" and is_auto_parallel:
            indices_strategy = [get_group_size()]
            indices_strategy.extend([1] * (indices_shape_size - 1))
            indices_strategy = tuple(indices_strategy)
            self.gatherv2.shard(((1, 1), indices_strategy))
            self.embeddinglookup.shard(((1, 1), indices_strategy))
        else:
            if is_auto_parallel:
                raise ValueError(f"For '{self.cls_name}', the 'slice_mode' should be one of values in "
                                 f"['field_slice', 'table_row_slice', 'table_column_slice', 'batch_slice'], "
                                 f"but got 'slice_mode': {slice_mode}")
        if self.cache_enable and not enable_ps:
            if parallel_mode != ParallelMode.STAND_ALONE:
                raise ValueError(f"For '{self.cls_name}', the 'parallel_mode' should be equal to "
                                 f"'ParallelMode.STAND_ALONE', but got {parallel_mode}.")
            self._set_cache_enable()
        self.embedding_table.unique = self.forward_unique
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = Validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float16)
        self.thor = True
        self.matrix_a = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                  name='matrix_a', requires_grad=False)
        self.matrix_g = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                  name="matrix_g", requires_grad=False)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.getG = P.InsertGradientOf(self.save_gradient)
        self.cast = P.Cast()
        self.cube_matmul = P.MatMul(transpose_a=True)
        self.mul = P.Mul()
        self.on_value = Tensor(1.0, self.dtype)
        self.off_value = Tensor(0.0, self.dtype)
        self.one_hot = P.OneHot()
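        # As in the other THOR layers, A is a per-token count vector of length
        # `vocab_size` and G is an (embedding_size, embedding_size) gradient
        # statistic captured by the InsertGradientOf hook registered above.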

    def save_gradient(self, dout):
        """Save the gradient statistic G; this function is only used by the THOR optimizer."""
        out = dout
        shape = self.shape(dout)
        normalizer = self.cast(shape[0], mstype.float16)
        dout = self.reshape(dout, (-1, self.embedding_size))
        matrix_g = self.cube_matmul(dout, dout)
        matrix_g = self.mul(matrix_g, 1.0 / normalizer)
        matrix_g = self.cast(matrix_g, mstype.float16)
        self.matrix_g = matrix_g
        return out

    def _set_cache_enable(self):
        """EmbeddingLookup cache check for the non-PS environment, which is only supported on 'Ascend'."""
        if self.target != 'DEVICE':
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
                             f"only when 'target' is 'DEVICE', but got 'target': {self.target}.")
        if not self.sparse:
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
                             f"only when 'sparse' is true, but got 'sparse': {self.sparse}.")
        if context.get_context("device_target") != 'Ascend':
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
                             f"only when 'device_target' is 'Ascend', but got {context.get_context('device_target')}.")
        logger.info("EmbeddingLookup cache enable takes effect.")
        self.forward_unique = True
        self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
        self.unique.add_prim_attr('cache_enable', True)
        self.embedding_table.cache_enable = self.cache_enable
        self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
        self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')

    def _process_vocab_cache(self, slice_mode):
        """PS embeddingLookup cache check and process."""
        self.cache_enable = False
        if self.vocab_cache_size > 0:
            if self.target == 'CPU':
                logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
                               "current target is CPU, so it will be ignored.")
                return
            enable_ps = _get_ps_context("enable_ps")
            if not enable_ps:
                logger.warning(
                    "The configuration of 'vocab_cache_size' is valid only in parameter server training "
                    "mode, current mode is not parameter server training mode, so it will be ignored.")
                return
            parallel_mode = _get_parallel_mode()
            is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
            if is_auto_parallel:
                rank_size = get_group_size()
                rank_id = get_rank()
                full_batch = _get_full_batch()
                if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
                    raise ValueError(f"For '{self.cls_name}', the embeddingLookup cache of parameter server parallel "
                                     f"can only be used in 'full_batch' and 'table_row_slice' parallel strategy, "
                                     f"but got 'full_batch': {full_batch}, 'slice_mode': {slice_mode}.")
                self.vocab_cache_size = self.vocab_cache_size * rank_size
                _set_rank_id(rank_id)
            self.cache_enable = True
            if _is_role_worker():
                self.vocab_size = self.vocab_cache_size
                if context.get_context("enable_sparse") != self.sparse:
                    raise ValueError(f"For '{self.cls_name}', the 'sparse' must be equal to the 'enable_sparse' "
                                     f"in context setting in parameter server cache mode, but got 'sparse': "
                                     f"{self.sparse}, 'enable_sparse': {context.get_context('enable_sparse')}.")

    def _set_vocab_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
        """PS embeddingLookup cache enable set."""
        self.embedding_table.cache_enable = True
        self.embedding_table.is_param_ps = True
        _set_cache_enable(True)
        if self.sparse:
            self.forward_unique = True
        if _is_role_worker():
            _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)

    def construct(self, indices):
        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, indices, 0)
        else:
            if self.thor:
                if self.forward_unique:
                    shp = self.shape(indices) + (self.embedding_size,)
                    indices_flatten = self.reshape_first(indices, (-1,))
                    unique_id, unique_idx = self.unique(indices_flatten)
                    one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value)
                    matrix_a = self.reduce_sum(one_hot_ids, 0)
                    matrix_a = self.cast(matrix_a, mstype.float16)
                    self.matrix_a = matrix_a
                    weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                    weight_unique = self.getG(weight_unique)
                    weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                    out = self.reshape(weight_flatten, shp)
                else:
                    indices_flatten = self.reshape_first(indices, (-1,))
                    one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value)
                    matrix_a = self.reduce_sum(one_hot_ids, 0)
                    matrix_a = self.cast(matrix_a, mstype.float16)
                    self.matrix_a = matrix_a
                    out = self.gatherv2(self.embedding_table, indices, 0)
                    out = self.getG(out)
            else:
                if self.forward_unique:
                    shp = self.shape(indices) + (self.embedding_size,)
                    indices_flatten = self.reshape_first(indices, (-1,))
                    unique_id, unique_idx = self.unique(indices_flatten)
                    weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                    weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                    out = self.reshape(weight_flatten, shp)
                else:
                    out = self.gatherv2(self.embedding_table, indices, 0)
        if self.max_norm is not None:
            axis = _make_axis_range(F.rank(indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)
        return out
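
# A minimal sketch (assumes an Ascend MindSpore install; mirrors the docstring
# example above, with the THOR statistics exposed the same way as in the other
# layers):
#
#     net = EmbeddingLookupThor(4, 2)
#     input_indices = Tensor(np.array([[1, 0], [3, 2]]), mstype.int32)
#     out = net(input_indices)   # (2, 2, 2); records net.matrix_a token counts
#     # net.matrix_g is captured on the backward pass via the InsertGradientOf hook.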