From 92880788f385abba0b81a6af37814f2ecdfb1732 Mon Sep 17 00:00:00 2001 From: wuxuejian Date: Mon, 22 Jun 2020 09:40:57 +0800 Subject: [PATCH] add aicpu embeddinglookup move embeddinglookup to the internal --- mindspore/ops/_grad/grad_array_ops.py | 26 +++++ mindspore/ops/_op_impl/aicpu/__init__.py | 1 + .../ops/_op_impl/aicpu/embedding_lookup.py | 102 ++++++++++++++++++ mindspore/ops/operations/_inner_ops.py | 70 ++++++++++++ mindspore/ops/operations/array_ops.py | 55 +++------- tests/st/ops/ascend/test_embedding_lookup.py | 42 ++++++++ .../python/parallel/test_embeddinglookup.py | 18 ++-- 7 files changed, 266 insertions(+), 48 deletions(-) create mode 100644 mindspore/ops/_op_impl/aicpu/embedding_lookup.py create mode 100644 tests/st/ops/ascend/test_embedding_lookup.py diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 72d8d74f46..b7b7af8082 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -17,6 +17,7 @@ from .. import operations as P from ..operations import _grad_ops as G +from ..operations import _inner_ops as inner from ..composite.multitype_ops.zeros_like_impl import zeros_like from .. import functional as F from .grad_base import bprop_getters @@ -188,6 +189,31 @@ def get_bprop_tile(self): return bprop +@bprop_getters.register(inner.EmbeddingLookup) +def get_bprop_embedding_lookup(self): + """Generate bprop for EmbeddingLookup""" + host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU') + host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU') + def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout): + x_shp = shape_op(x) + if reduce_scatter_flag is True: + elu_grad = G.EmbeddingLookupCommGrad() + actual_dout = elu_grad(dout, split_num) + else: + actual_dout = dout + new_indices = host_sub(indices - offset) + # Reshape the 'new_indices' + new_indices_shape_changed = (size_op(new_indices),) + new_indices = host_reshape(new_indices, new_indices_shape_changed) + # Reshape the 'actual_dout' + x_shp_tail = x_shp[1:] + actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail + actual_dout = host_reshape(actual_dout, actual_dout_shape_changed) + return (new_indices, actual_dout, x_shp), zeros_like(new_indices), zeros_like(axis), \ + zeros_like(reduce_scatter_flag), zeros_like(split_num) + return bprop_sparse + + @bprop_getters.register(P.Transpose) def get_bprop_transpose(self): """Generate bprop for Transpose""" diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index 4709714de0..48df11c23a 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -14,6 +14,7 @@ """aicpu ops""" from .init_data_set_queue import _init_data_set_queue_aicpu +from .embedding_lookup import _embedding_lookup_aicpu from .dropout_genmask import _dropout_genmask_aicpu from .get_next import _get_next_aicpu from .print_tensor import _print_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/embedding_lookup.py b/mindspore/ops/_op_impl/aicpu/embedding_lookup.py new file mode 100644 index 0000000000..8eecc5145d --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/embedding_lookup.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""EmbeddingLookup op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +embeddingLookup_op_info = AiCPURegOp("EmbeddingLookup") \ + .fusion_type("OPAQUE") \ + .input(0, "params", "required") \ + .input(1, "indices", "required") \ + .input(2, "offset", "required") \ + .output(0, "output", "required") \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.I32_Default, \ + DataType.I32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.I64_Default, \ + DataType.I32_Default, DataType.BOOL_Default) \ + .get_op_info() + +@op_info_register(embeddingLookup_op_info) +def _embedding_lookup_aicpu(): + """EmbeddingLookup AiCPU register""" + return diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 38f399316a..2f9970eb0c 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -96,3 +96,73 @@ class ExtractImagePatches(PrimitiveWithInfer): """infer dtype""" validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) return input_x + + +class EmbeddingLookup(PrimitiveWithInfer): + """ + Returns a slice of input tensor based on the specified indices. + + This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs: + `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices. + + Inputs: + - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + The Tensor slice, instead of the entire Tensor. + - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. + Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, + and the exceeding part will be filled with 0 in the output. + - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices + are equal to `input_indices` minus `offset`. + - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. + Only constant value is allowed. + - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable + is used only if `reduce_scatter_flag` is True. Only constant value is allowed. + + + Outputs: + Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. + + Examples: + >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) + >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) + >>> offset = 4 + >>> reduce_scatter_flag = False + >>> split_num = 1 + >>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num) + [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] + """ + @prim_attr_register + def __init__(self): + """init index_select""" + self.__setattr_flag__ = True + self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'], + outputs=['output']) + self.add_prim_attr('primitive_target', 'CPU') + + def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2): + validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) + validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) + validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) + validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name) + if split_num['value'] < 1: + raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num) + params_shp = params['shape'] + out_shape = indices['shape'] + params_shp[1:] + if reduce_scatter_flag is None: + raise ValueError("The value of 'reduce_scatter_flag' is None.") + reduce_scatter_flag_value = reduce_scatter_flag['value'] + if split_num is None: + raise ValueError("The value of 'split_num_value' is None.") + split_num_value = split_num['value'] + if reduce_scatter_flag_value is True: + # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by + # (split_num * 8) + if out_shape[0] % (split_num_value * 8) != 0: + raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." % + (out_shape[0], (split_num_value * 8))) + # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8 + out_shape[0] = out_shape[0] // 8 + out = {'shape': out_shape, + 'dtype': params['dtype'], + 'value': None} + return out diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index d53f92c2a3..79a92ed7c8 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -577,64 +577,43 @@ class Range(PrimitiveWithInfer): class EmbeddingLookup(PrimitiveWithInfer): """ Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar - functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`. + functionality as GatherV2, but has one more inputs: `offset`. + This primitive runs on the acipu devices. Inputs: - - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + - **params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The Tensor slice, instead of the entire Tensor. - - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. - Specifies the indices of elements of the original Tensor. Must be in the range - `[0, input_param.shape()[axis])`. - - **axis** (int) - Specifies the dimension index to gather indices. - - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices - are equal to `input_indices` minus `offset`. - - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. - - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable - is used only if `reduce_scatter_flag` is True. + - **indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. + Specifies the indices of elements of the original Tensor. Values can be out of range of `params`, + and the exceeding part will be filled with 0 in the output. + The indices to do lookup operation whose data type should be mindspore.int32 or mindspore.int64. + - **offset** (int) - Specifies the offset value of this `params` slice. Thus the real indices + are equal to `indices` minus `offset`. Outputs: Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. Examples: - >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) - >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) - >>> axis = 0 + >>> params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) + >>> indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) >>> offset = 4 - >>> reduce_scatter_flag = False - >>> split_num = 1 - >>> out = P.EmbeddingLookup()(input_params, input_indices, axis, offset, reduce_scatter_flag, split_num) + >>> out = P.EmbeddingLookup()(params, indices, offset) [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] """ @prim_attr_register def __init__(self): """init index_select""" - self.__setattr_flag__ = True - self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'], + self.init_prim_io_names(inputs=['params', 'indices', 'offset'], outputs=['output']) - self.add_prim_attr('target', 'CPU') - def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2): + def __infer__(self, params, indices, offset): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) - validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) - validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) + valid_types = (mstype.int32, mstype.int64) + validator.check_tensor_type_same({"indices": indices['dtype']}, valid_types, self.name) validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) - validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name) - if split_num['value'] < 1: - raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num) - axis_v = axis['value'] params_shp = params['shape'] - rank = len(params_shp) - validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) - if axis_v < 0: - axis_v += rank - out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] - if reduce_scatter_flag: - # partition the tensor along the dimension 0. - if out_shape[0] % split_num['value'] != 0: - raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." % - (out_shape[0], split_num['value'])) - out_shape[0] = out_shape[0] // split_num['value'] + out_shape = indices['shape'] + params_shp[1:] out = {'shape': out_shape, 'dtype': params['dtype'], 'value': None} diff --git a/tests/st/ops/ascend/test_embedding_lookup.py b/tests/st/ops/ascend/test_embedding_lookup.py new file mode 100644 index 0000000000..483fdcdbc4 --- /dev/null +++ b/tests/st/ops/ascend/test_embedding_lookup.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.common.dtype as mstype +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, offset): + super(Net, self).__init__() + self.embedding = P.EmbeddingLookup() + self.offset = offset + + def construct(self, param, index): + return self.embedding(param, index, self.offset) + + +def test_embedding_lookup_sparse(): + params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.int32) + indices = Tensor(np.array([[5, 2], [8, 5]]), mstype.int32) + offset = 4 + embedding = Net(offset) + out = embedding(params, indices) + assert(out.asnumpy() == [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]).all() diff --git a/tests/ut/python/parallel/test_embeddinglookup.py b/tests/ut/python/parallel/test_embeddinglookup.py index b934028a48..b306061981 100644 --- a/tests/ut/python/parallel/test_embeddinglookup.py +++ b/tests/ut/python/parallel/test_embeddinglookup.py @@ -19,6 +19,7 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore.common.api import _executor from mindspore.ops import operations as P +from mindspore.ops.operations import _inner_ops as inner from tests.ut.python.ops.test_math_ops import VirtualLoss @@ -33,29 +34,27 @@ class NetWithLoss(nn.Cell): return self.loss(predict) class Net(nn.Cell): - def __init__(self, shape, axis, offset, reduce_scatter_flag, split_num): + def __init__(self, shape, offset, reduce_scatter_flag, split_num): super().__init__() self.index = Tensor(np.ones(shape), dtype=ms.int32) - self.axis = axis self.offset = offset self.reduce_scatter_flag = reduce_scatter_flag self.split_num = split_num - self.elu = P.EmbeddingLookup() + self.elu = inner.EmbeddingLookup() self.mm = P.BatchMatMul() def construct(self, x, y): - out = self.elu(x, self.index, self.axis, self.offset, self.reduce_scatter_flag, self.split_num) + out = self.elu(x, self.index, self.offset, self.reduce_scatter_flag, self.split_num) out = self.mm(out, y) return out def test_embeddinglookup_reducescatter_false(): shape = [8, 8] - axis = 0 offset = 8 reduce_scatter_flag = False split_num = 1 - net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num)) + net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)) net.set_auto_parallel() x = Tensor(np.ones([64, 32]), dtype=ms.float32) @@ -64,14 +63,13 @@ def test_embeddinglookup_reducescatter_false(): def test_embeddinglookup_reducescatter_true(): - shape = [8, 8] - axis = 0 + shape = [64, 8] offset = 8 reduce_scatter_flag = True split_num = 8 - net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num)) + net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)) net.set_auto_parallel() x = Tensor(np.ones([64, 32]), dtype=ms.float32) - y = Tensor(np.ones([1, 32, 8]), dtype=ms.float32) + y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32) _executor.compile(net, x, y)