test_indexed_slices.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. @File : test_indexed_slices.py
  17. @Author:
  18. @Date : 2020-06-08
  19. @Desc : test mindspore indexed_slices's operation
  20. """
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore import Tensor, IndexedSlices, context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn import Optimizer
from mindspore.nn import TrainOneStepCell, WithLossCell
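
# These tests exercise IndexedSlices in graph mode, which needs sparse support enabled.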
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()

@constexpr
def _generate_shape_index(out_shape, indices_shape, axis):
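    """Build the permutation that moves the gathered-index dimensions to the front."""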
    out_rank = len(out_shape)
    ind_rank = len(indices_shape)
    if axis < 0:
        axis += out_rank - ind_rank + 1
    perm_part1 = tuple(range(axis, axis + ind_rank))
    index = tuple(range(out_rank))
    perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
    return perm

@constexpr
def _generate_inverse_index(x_shape, axis):
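    """Build the permutation that moves dimension 0 back into position `axis`."""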
    x_rank = len(x_shape)
    index = tuple(range(x_rank))
    if axis < 0:
        axis += x_rank
    perm = index[1:1 + axis] + (0,) + index[1 + axis:]
    return perm

class MySparseGatherV2(P.GatherV2):
    """GatherV2 variant with a custom bprop, used only by these tests."""

@bprop_getters.register(MySparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for MySparseGatherV2"""
    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
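        # For axis 0 the gradient can stay sparse: return an IndexedSlices
        # instead of scattering dout into a dense tensor of x's shape.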
        if axis == 0:
            indices_size = (size_op(indices),)
            x_tail_shp = x_shp[1:]
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
            return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(indices), zeros_like(axis)
    return bprop

adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")

@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool")
def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param,
                                          m, v, gradient, decay_flag):
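    # Sparse overload used as a stub: just return the dense values
    # of the IndexedSlices gradient instead of computing an update.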
    return gradient.values()

@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
                                  m, v, gradient, decay_flag):
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()
    param_fp32 = op_cast(param, mstype.float32)
    m_fp32 = op_cast(m, mstype.float32)
    v_fp32 = op_cast(v, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)
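    # Standard Adam moment updates:
    #   m <- beta1 * m + (1 - beta1) * g
    #   v <- beta2 * v + (1 - beta2) * g^2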
    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                            - beta2, op_square(gradient_fp32))
    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param_fp32)
    update_with_lr = op_mul(lr, update)
    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
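    # Chain the assignments through depend so the graph compiler
    # does not prune the parameter and moment updates.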
    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v

def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)

class AdamWeightDecaySparse(Optimizer):
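    """Adam-with-weight-decay optimizer that accepts IndexedSlices gradients via adam_opt_for_map."""
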
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
        if self.is_group:
            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        self.decay_flag = tuple(decay_filter(x) for x in self.params)
        self.map = C.Map()

    def construct(self, gradients):
        lr = self.get_lr()
        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
                                              self.weight_decay_tensor),
                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        return updated_velocity

def test_indexed_slices_make_indexed_slices():
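    """Construct an IndexedSlices inside a Cell and return it."""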
    class MakeIndexedSlices(nn.Cell):
        def __init__(self):
            super(MakeIndexedSlices, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            ret = (IndexedSlices(indices, values, self.dense_shape),)
            return ret[0]

    indices = Tensor([[0, 0], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    MakeIndexedSlices()(indices, values)

def test_indexed_slices_attr():
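    """Read back values(), indices() and dense_shape() from an IndexedSlices."""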
    class IndexedSlicesGetAttr(nn.Cell):
        def __init__(self):
            super(IndexedSlicesGetAttr, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            x = IndexedSlices(indices, values, self.dense_shape)
            return x.values(), x.indices(), x.dense_shape()

    indices = Tensor([[0, 0], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    IndexedSlicesGetAttr()(indices, values)

def test_indexed_slices_sparse_gatherv2_grad_all():
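    """Differentiate MySparseGatherV2 w.r.t. all inputs; the params gradient comes back as an IndexedSlices."""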
    grad_all = C.GradOperation('get_all', get_all=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            grad = grad_all(self.network)(x, y)
            return grad, grad[0], grad[1]

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0

        def construct(self, params, indices):
            return self.sparse_gatherv2(params, indices, self.axis)

    params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    GradWrap(SparseGatherV2())(params, indices)

def test_indexed_slices_sparse_gatherv2_grad_with_param():
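    """Differentiate MySparseGatherV2 w.r.t. a Parameter and unpack the IndexedSlices gradient."""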
    grad_by_list = C.GradOperation('get_by_list', get_by_list=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

        def construct(self, x):
            weights = self.weights
            grad = grad_by_list(self.network, weights)(x)
            x = grad[0]
            return x, x.values(), x.indices(), x.dense_shape()

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
            self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params")

        def construct(self, indices):
            return self.sparse_gatherv2(self.params, indices, self.axis)

    indices = Tensor(np.array([0, 1]).astype(np.int32))
    network = GradWrap(SparseGatherV2())
    network(indices)

def test_indexed_slices_env_get():
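    """Run one training step so an IndexedSlices gradient flows end-to-end through the sparse optimizer."""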
    class Loss(nn.Cell):
        def __init__(self):
            super(Loss, self).__init__()

        def construct(self, base, target):
            return base

    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = MySparseGatherV2()
            self.axis = 0

        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2

    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    train_network(inputs, label)