You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_indexed_slices.py 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_indexed_slices.py
  17. @Author:
  18. @Date : 2020-06-08
  19. @Desc : test mindspore indexed_slices's operation
  20. """
  21. import numpy as np
  22. import mindspore as ms
  23. import mindspore.nn as nn
  24. from mindspore.ops import composite as C
  25. from mindspore.ops import functional as F
  26. from mindspore.ops import operations as P
  27. from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
  28. from mindspore.ops.primitive import constexpr
  29. from mindspore.ops._grad.grad_base import bprop_getters
  30. from mindspore import Tensor, IndexedSlices, context
  31. from mindspore.common.parameter import Parameter, ParameterTuple
  32. from mindspore.common import dtype as mstype
  33. from mindspore._checkparam import Validator as validator
  34. from mindspore._checkparam import Rel
  35. from mindspore.nn import Optimizer
  36. from mindspore.nn import TrainOneStepCell, WithLossCell
  37. from mindspore.nn.optim import Momentum
  38. from mindspore.train import Model
  39. from ....dataset_mock import MindData
# Run every test in this module in graph mode; enable_sparse is presumably what
# allows IndexedSlices values to be used in graphs — confirm against the
# MindSpore context documentation.
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
# Module-level primitive instances shared by the helpers in this module
# (get_axis uses shape_op; the MySparseGatherV2 bprop uses shape_op, size_op,
# reshape, transpose and unsorted_segment_sum).
# NOTE(review): reduce_sum, invert_permutation and logical_and look unused in
# this module — confirm before removing.
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
  49. def get_axis(x):
  50. shape = shape_op(x)
  51. length = F.tuple_len(shape)
  52. perm = F.make_range(0, length)
  53. return perm
  54. class MSELoss(nn.Cell):
  55. def __init__(self):
  56. super(MSELoss, self).__init__()
  57. self.reduce_sum = P.ReduceSum()
  58. self.square = P.Square()
  59. self.reduce_mean = P.ReduceMean()
  60. def construct(self, data, label):
  61. diff = data - label
  62. return self.reduce_mean(self.square(diff), get_axis(diff))
  63. class MindDataSet(MindData):
  64. def __init__(self, dataset_types, dataset_shapes):
  65. super(MindDataSet, self).__init__(size=2, batch_size=32,
  66. np_types=dataset_types,
  67. output_shapes=dataset_shapes,
  68. input_indexs=(0, 1))
  69. def __next__(self):
  70. if self._size < self._iter_num:
  71. raise StopIteration
  72. self._iter_num += 1
  73. lst = []
  74. for shape_, type_ in zip(self._output_shapes, self._np_types):
  75. lst.append(Tensor(np.ones(shape_).astype(type_)))
  76. return tuple(lst)
  77. @constexpr
  78. def _generate_shape_index(out_shape, indices_shape, axis):
  79. out_rank = len(out_shape)
  80. ind_rank = len(indices_shape)
  81. if axis < 0:
  82. axis += out_rank - ind_rank + 1
  83. perm_part1 = tuple(range(axis, axis + ind_rank))
  84. index = tuple(range(out_rank))
  85. perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
  86. return perm
  87. @constexpr
  88. def _generate_inverse_index(x_shape, axis):
  89. x_rank = len(x_shape)
  90. index = tuple(range(x_rank))
  91. if axis < 0:
  92. axis += x_rank
  93. perm = index[1:1 + axis] + (0,) + index[1 + axis:]
  94. return perm
  95. class MySparseGatherV2(P.GatherV2):
  96. """
  97. For test
  98. """
@bprop_getters.register(MySparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for MySparseGatherV2.

    Registered with ``bprop_getters`` for ``MySparseGatherV2``; presumably this
    is the backward function used when differentiating through that primitive.
    When ``axis == 0`` the gradient w.r.t. ``x`` is returned as an
    ``IndexedSlices``. Otherwise a dense gradient is built with
    ``unsorted_segment_sum`` and two transposes.
    """
    def bprop(x, indices, axis, out, dout):
        # Returns one gradient per forward input: (d_x, d_indices, d_axis).
        x_shp = shape_op(x)
        if axis == 0:
            # Sparse path: flatten indices to 1-D and reshape dout to
            # (num_indices,) + x_shp[1:] so each row of `values` pairs with
            # one flattened index.
            indices_size = (size_op(indices),)
            x_tail_shp = x_shp[1:]
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
            return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        # Dense path: promote scalar dout / indices to rank 1 so the transpose
        # and segment-sum below have an axis to operate on.
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        # NOTE(review): on both paths zeros_like(indices) is taken from the
        # reshaped/expanded indices rather than the original input — confirm
        # this shape is acceptable.
        return params_grad, zeros_like(indices), zeros_like(axis)
    return bprop
  126. adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
  127. @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
  128. "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool")
  129. def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param,
  130. m, v, gradient, decay_flag):
  131. return gradient.values()
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
                                  m, v, gradient, decay_flag):
    """Apply an Adam-with-weight-decay step for a dense (Tensor) gradient.

    All arithmetic is done on float32 casts of the inputs. The new parameter
    and moments are written back to ``param``, ``m`` and ``v`` with
    ``F.assign``. No bias-correction terms are applied.

    Args:
        beta1: Decay rate for the first moment ``m``.
        beta2: Decay rate for the second moment ``v``.
        eps: Added to sqrt(next_v) in the denominator.
        lr: Learning rate multiplied into the update.
        weight_decay_tensor: Coefficient of the ``param`` term added when ``decay_flag`` is true.
        param: Parameter to update.
        m: First-moment buffer.
        v: Second-moment buffer.
        gradient: Dense gradient for ``param``.
        decay_flag: Whether to add weight decay for this parameter.

    Returns:
        The new second moment, with the three assignments chained in via ``F.depend``.
    """
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()
    param_fp32 = op_cast(param, mstype.float32)
    m_fp32 = op_cast(m, mstype.float32)
    v_fp32 = op_cast(v, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)
    # next_m = beta1 * m + (1 - beta1) * g
    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
    # next_v = beta2 * v + (1 - beta2) * g^2
    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                            - beta2, op_square(gradient_fp32))
    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param_fp32)
    update_with_lr = op_mul(lr, update)
    # Reshape the scaled update to param's shape before subtracting.
    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
    # Each assignment is attached to the returned value with F.depend,
    # presumably so the writes are kept in the graph. Keep this order.
    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v
  158. def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
  159. """Check the type of inputs."""
  160. validator.check_value_type("beta1", beta1, [float], prim_name)
  161. validator.check_value_type("beta2", beta2, [float], prim_name)
  162. validator.check_value_type("eps", eps, [float], prim_name)
  163. validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
  164. validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
  165. validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
  166. validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
  167. validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
  168. class AdamWeightDecaySparse(Optimizer):
  169. def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
  170. decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
  171. super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
  172. if self.is_group:
  173. raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
  174. _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
  175. self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
  176. self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
  177. self.eps = Tensor(np.array([eps]).astype(np.float32))
  178. self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
  179. self.params = self.parameters
  180. self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
  181. self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
  182. self.decay_flag = tuple(decay_filter(x) for x in self.params)
  183. self.map = C.Map()
  184. def construct(self, gradients):
  185. lr = self.get_lr()
  186. updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
  187. self.weight_decay_tensor),
  188. self.params, self.moments1, self.moments2, gradients, self.decay_flag)
  189. return updated_velocity
def test_indexed_slices_make_indexed_slices():
    """Construct an IndexedSlices inside a graph-mode ``construct``.

    There are no assertions: the test passes if compiling and running the
    cell does not raise.
    """
    class MakeIndexedSlices(nn.Cell):
        def __init__(self):
            super(MakeIndexedSlices, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            # Wrapping the IndexedSlices in a 1-tuple and indexing it back out
            # exercises an IndexedSlices value passing through a tuple.
            ret = (IndexedSlices(indices, values, self.dense_shape),)
            return ret[0]
    # NOTE(review): values' trailing dim (2) differs from dense_shape's (4);
    # presumably not validated at construction — confirm this is intended.
    indices = Tensor([1, 2])
    values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
    MakeIndexedSlices()(indices, values)
def test_indexed_slices_attr():
    """Read ``values()``, ``indices()`` and ``dense_shape()`` of an IndexedSlices in graph mode.

    There are no assertions: the test passes if compiling and running the
    cell does not raise.
    """
    class IndexedSlicesGetAttr(nn.Cell):
        def __init__(self):
            super(IndexedSlicesGetAttr, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            x = IndexedSlices(indices, values, self.dense_shape)
            return x.values(), x.indices(), x.dense_shape()
    indices = Tensor([0])
    values = Tensor([[1, 2]], dtype=ms.float32)
    IndexedSlicesGetAttr()(indices, values)
def test_indexed_slices_sparse_gatherv2_grad_all():
    """Differentiate w.r.t. all inputs through MySparseGatherV2 with axis 0.

    With ``axis == 0`` the gradient for ``params`` comes from the IndexedSlices
    branch of the bprop registered for MySparseGatherV2. There are no
    assertions: the test passes if compiling and running does not raise.
    """
    grad_all = C.GradOperation('get_all', get_all=True)
    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            grad = grad_all(self.network)(x, y)
            # Return the whole gradient tuple and each element on its own.
            return grad, grad[0], grad[1]
    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0

        def construct(self, params, indices):
            return self.sparse_gatherv2(params, indices, self.axis)
    params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    GradWrap(SparseGatherV2())(params, indices)
def test_indexed_slices_sparse_gatherv2_grad_with_pram():
    """Differentiate w.r.t. a Parameter gathered by MySparseGatherV2 and read the result's attributes.

    The gradient of ``params`` (axis 0) is used as an IndexedSlices: its
    ``values()``, ``indices()`` and ``dense_shape()`` are returned. There are
    no assertions: the test passes if compiling and running does not raise.
    """
    # NOTE(review): "pram" in the test name looks like a typo for "param";
    # left unchanged to keep the test id stable.
    grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            # Only parameters with requires_grad take part in differentiation.
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

        def construct(self, x):
            weights = self.weights
            grad = grad_by_list(self.network, weights)(x)
            x = grad[0]
            return x, x.values(), x.indices(), x.dense_shape()
    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
            self.sparse_gatherv2 = MySparseGatherV2()
            self.axis = 0
            self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params")

        def construct(self, indices):
            return self.sparse_gatherv2(self.params, indices, self.axis)
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    network = GradWrap(SparseGatherV2())
    network(indices)
def test_indexed_slices_env_get():
    """Run one TrainOneStepCell step with AdamWeightDecaySparse on a gather-based net.

    ``w1`` is read through MySparseGatherV2 (axis 0), so its gradient is
    expected to be an IndexedSlices. ``w2`` is used densely. There are no
    assertions: the test passes if compiling and running does not raise.
    """
    class Loss(nn.Cell):
        def __init__(self):
            super(Loss, self).__init__()

        def construct(self, base, target):
            # Ignores target; the network output itself is the loss.
            return base
    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = MySparseGatherV2()
            self.axis = 0

        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2
    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())
    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    train_network(inputs, label)
def test_indexed_slices_model_train():
    """Train a small dense net with Model + Momentum for 2 epochs on MindDataSet (non-sink mode).

    There are no assertions: the test passes if training runs without raising.
    """
    class Net(nn.Cell):
        def __init__(self, in_features, out_features):
            super(Net, self).__init__()
            self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
            self.add = P.TensorAdd()
            self.cast = P.Cast()
            self.flag = True

        def construct(self, inputs, label):
            # label is accepted but not used.
            x = self.add(inputs, self.weight)
            if self.flag:
                x = self.cast(x, mstype.float32)
            return x
    dataset_types = (np.float32, np.float32)
    dataset_shapes = ((16, 16), (16, 16))
    dataset = MindDataSet(dataset_types, dataset_shapes)
    net = Net(16, 16)
    net.set_train()
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    # No loss_fn is passed, so the network output is presumably used directly
    # as the loss — confirm against the Model documentation.
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)