Merge pull request !236 from wuxuejian/incu_embeddingtags/v0.6.0-beta
| @@ -17,6 +17,7 @@ | |||||
| from .. import operations as P | from .. import operations as P | ||||
| from ..operations import _grad_ops as G | from ..operations import _grad_ops as G | ||||
| from ..operations import _inner_ops as inner | |||||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | from ..composite.multitype_ops.zeros_like_impl import zeros_like | ||||
| from .. import functional as F | from .. import functional as F | ||||
| from .grad_base import bprop_getters | from .grad_base import bprop_getters | ||||
| @@ -188,6 +189,31 @@ def get_bprop_tile(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(inner.EmbeddingLookup) | |||||
| def get_bprop_embedding_lookup(self): | |||||
| """Generate bprop for EmbeddingLookup""" | |||||
| host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU') | |||||
| host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU') | |||||
| def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout): | |||||
| x_shp = shape_op(x) | |||||
| if reduce_scatter_flag is True: | |||||
| elu_grad = G.EmbeddingLookupCommGrad() | |||||
| actual_dout = elu_grad(dout, split_num) | |||||
| else: | |||||
| actual_dout = dout | |||||
| new_indices = host_sub(indices - offset) | |||||
| # Reshape the 'new_indices' | |||||
| new_indices_shape_changed = (size_op(new_indices),) | |||||
| new_indices = host_reshape(new_indices, new_indices_shape_changed) | |||||
| # Reshape the 'actual_dout' | |||||
| x_shp_tail = x_shp[1:] | |||||
| actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail | |||||
| actual_dout = host_reshape(actual_dout, actual_dout_shape_changed) | |||||
| return (new_indices, actual_dout, x_shp), zeros_like(new_indices), zeros_like(axis), \ | |||||
| zeros_like(reduce_scatter_flag), zeros_like(split_num) | |||||
| return bprop_sparse | |||||
| @bprop_getters.register(P.Transpose) | @bprop_getters.register(P.Transpose) | ||||
| def get_bprop_transpose(self): | def get_bprop_transpose(self): | ||||
| """Generate bprop for Transpose""" | """Generate bprop for Transpose""" | ||||
| @@ -14,6 +14,7 @@ | |||||
| """aicpu ops""" | """aicpu ops""" | ||||
| from .init_data_set_queue import _init_data_set_queue_aicpu | from .init_data_set_queue import _init_data_set_queue_aicpu | ||||
| from .embedding_lookup import _embedding_lookup_aicpu | |||||
| from .dropout_genmask import _dropout_genmask_aicpu | from .dropout_genmask import _dropout_genmask_aicpu | ||||
| from .get_next import _get_next_aicpu | from .get_next import _get_next_aicpu | ||||
| from .print_tensor import _print_aicpu | from .print_tensor import _print_aicpu | ||||
| @@ -0,0 +1,102 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """EmbeddingLookup op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| embeddingLookup_op_info = AiCPURegOp("EmbeddingLookup") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "params", "required") \ | |||||
| .input(1, "indices", "required") \ | |||||
| .input(2, "offset", "required") \ | |||||
| .output(0, "output", "required") \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, \ | |||||
| DataType.I32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I64_Default, \ | |||||
| DataType.I64_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I64_Default, \ | |||||
| DataType.I32_Default, DataType.BOOL_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(embeddingLookup_op_info) | |||||
| def _embedding_lookup_aicpu(): | |||||
| """EmbeddingLookup AiCPU register""" | |||||
| return | |||||
| @@ -96,3 +96,73 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||||
| """infer dtype""" | """infer dtype""" | ||||
| validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) | validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) | ||||
| return input_x | return input_x | ||||
| class EmbeddingLookup(PrimitiveWithInfer): | |||||
| """ | |||||
| Returns a slice of input tensor based on the specified indices. | |||||
| This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs: | |||||
| `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices. | |||||
| Inputs: | |||||
| - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||||
| The Tensor slice, instead of the entire Tensor. | |||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||||
| Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, | |||||
| and the exceeding part will be filled with 0 in the output. | |||||
| - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices | |||||
| are equal to `input_indices` minus `offset`. | |||||
| - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. | |||||
| Only constant value is allowed. | |||||
| - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable | |||||
| is used only if `reduce_scatter_flag` is True. Only constant value is allowed. | |||||
| Outputs: | |||||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | |||||
| Examples: | |||||
| >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) | |||||
| >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) | |||||
| >>> offset = 4 | |||||
| >>> reduce_scatter_flag = False | |||||
| >>> split_num = 1 | |||||
| >>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num) | |||||
| [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init index_select""" | |||||
| self.__setattr_flag__ = True | |||||
| self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'], | |||||
| outputs=['output']) | |||||
| self.add_prim_attr('primitive_target', 'CPU') | |||||
| def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2): | |||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||||
| validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | |||||
| validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name) | |||||
| if split_num['value'] < 1: | |||||
| raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num) | |||||
| params_shp = params['shape'] | |||||
| out_shape = indices['shape'] + params_shp[1:] | |||||
| if reduce_scatter_flag is None: | |||||
| raise ValueError("The value of 'reduce_scatter_flag' is None.") | |||||
| reduce_scatter_flag_value = reduce_scatter_flag['value'] | |||||
| if split_num is None: | |||||
| raise ValueError("The value of 'split_num_value' is None.") | |||||
| split_num_value = split_num['value'] | |||||
| if reduce_scatter_flag_value is True: | |||||
| # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by | |||||
| # (split_num * 8) | |||||
| if out_shape[0] % (split_num_value * 8) != 0: | |||||
| raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." % | |||||
| (out_shape[0], (split_num_value * 8))) | |||||
| # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8 | |||||
| out_shape[0] = out_shape[0] // 8 | |||||
| out = {'shape': out_shape, | |||||
| 'dtype': params['dtype'], | |||||
| 'value': None} | |||||
| return out | |||||
| @@ -577,64 +577,43 @@ class Range(PrimitiveWithInfer): | |||||
| class EmbeddingLookup(PrimitiveWithInfer): | class EmbeddingLookup(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar | Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar | ||||
| functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`. | |||||
| functionality as GatherV2, but has one more inputs: `offset`. | |||||
| This primitive runs on the acipu devices. | |||||
| Inputs: | Inputs: | ||||
| - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||||
| - **params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||||
| The Tensor slice, instead of the entire Tensor. | The Tensor slice, instead of the entire Tensor. | ||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||||
| Specifies the indices of elements of the original Tensor. Must be in the range | |||||
| `[0, input_param.shape()[axis])`. | |||||
| - **axis** (int) - Specifies the dimension index to gather indices. | |||||
| - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices | |||||
| are equal to `input_indices` minus `offset`. | |||||
| - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. | |||||
| - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable | |||||
| is used only if `reduce_scatter_flag` is True. | |||||
| - **indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||||
| Specifies the indices of elements of the original Tensor. Values can be out of range of `params`, | |||||
| and the exceeding part will be filled with 0 in the output. | |||||
| The indices to do lookup operation whose data type should be mindspore.int32 or mindspore.int64. | |||||
| - **offset** (int) - Specifies the offset value of this `params` slice. Thus the real indices | |||||
| are equal to `indices` minus `offset`. | |||||
| Outputs: | Outputs: | ||||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | ||||
| Examples: | Examples: | ||||
| >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) | |||||
| >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) | |||||
| >>> axis = 0 | |||||
| >>> params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) | |||||
| >>> indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) | |||||
| >>> offset = 4 | >>> offset = 4 | ||||
| >>> reduce_scatter_flag = False | |||||
| >>> split_num = 1 | |||||
| >>> out = P.EmbeddingLookup()(input_params, input_indices, axis, offset, reduce_scatter_flag, split_num) | |||||
| >>> out = P.EmbeddingLookup()(params, indices, offset) | |||||
| [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """init index_select""" | """init index_select""" | ||||
| self.__setattr_flag__ = True | |||||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'], | |||||
| self.init_prim_io_names(inputs=['params', 'indices', 'offset'], | |||||
| outputs=['output']) | outputs=['output']) | ||||
| self.add_prim_attr('target', 'CPU') | |||||
| def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2): | |||||
| def __infer__(self, params, indices, offset): | |||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | ||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||||
| validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) | |||||
| valid_types = (mstype.int32, mstype.int64) | |||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, valid_types, self.name) | |||||
| validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | ||||
| validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name) | |||||
| if split_num['value'] < 1: | |||||
| raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num) | |||||
| axis_v = axis['value'] | |||||
| params_shp = params['shape'] | params_shp = params['shape'] | ||||
| rank = len(params_shp) | |||||
| validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) | |||||
| if axis_v < 0: | |||||
| axis_v += rank | |||||
| out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] | |||||
| if reduce_scatter_flag: | |||||
| # partition the tensor along the dimension 0. | |||||
| if out_shape[0] % split_num['value'] != 0: | |||||
| raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." % | |||||
| (out_shape[0], split_num['value'])) | |||||
| out_shape[0] = out_shape[0] // split_num['value'] | |||||
| out_shape = indices['shape'] + params_shp[1:] | |||||
| out = {'shape': out_shape, | out = {'shape': out_shape, | ||||
| 'dtype': params['dtype'], | 'dtype': params['dtype'], | ||||
| 'value': None} | 'value': None} | ||||
| @@ -0,0 +1,42 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, offset): | |||||
| super(Net, self).__init__() | |||||
| self.embedding = P.EmbeddingLookup() | |||||
| self.offset = offset | |||||
| def construct(self, param, index): | |||||
| return self.embedding(param, index, self.offset) | |||||
| def test_embedding_lookup_sparse(): | |||||
| params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.int32) | |||||
| indices = Tensor(np.array([[5, 2], [8, 5]]), mstype.int32) | |||||
| offset = 4 | |||||
| embedding = Net(offset) | |||||
| out = embedding(params, indices) | |||||
| assert(out.asnumpy() == [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]).all() | |||||
| @@ -19,6 +19,7 @@ import mindspore.nn as nn | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.common.api import _executor | from mindspore.common.api import _executor | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.operations import _inner_ops as inner | |||||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | from tests.ut.python.ops.test_math_ops import VirtualLoss | ||||
| @@ -33,29 +34,27 @@ class NetWithLoss(nn.Cell): | |||||
| return self.loss(predict) | return self.loss(predict) | ||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self, shape, axis, offset, reduce_scatter_flag, split_num): | |||||
| def __init__(self, shape, offset, reduce_scatter_flag, split_num): | |||||
| super().__init__() | super().__init__() | ||||
| self.index = Tensor(np.ones(shape), dtype=ms.int32) | self.index = Tensor(np.ones(shape), dtype=ms.int32) | ||||
| self.axis = axis | |||||
| self.offset = offset | self.offset = offset | ||||
| self.reduce_scatter_flag = reduce_scatter_flag | self.reduce_scatter_flag = reduce_scatter_flag | ||||
| self.split_num = split_num | self.split_num = split_num | ||||
| self.elu = P.EmbeddingLookup() | |||||
| self.elu = inner.EmbeddingLookup() | |||||
| self.mm = P.BatchMatMul() | self.mm = P.BatchMatMul() | ||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| out = self.elu(x, self.index, self.axis, self.offset, self.reduce_scatter_flag, self.split_num) | |||||
| out = self.elu(x, self.index, self.offset, self.reduce_scatter_flag, self.split_num) | |||||
| out = self.mm(out, y) | out = self.mm(out, y) | ||||
| return out | return out | ||||
| def test_embeddinglookup_reducescatter_false(): | def test_embeddinglookup_reducescatter_false(): | ||||
| shape = [8, 8] | shape = [8, 8] | ||||
| axis = 0 | |||||
| offset = 8 | offset = 8 | ||||
| reduce_scatter_flag = False | reduce_scatter_flag = False | ||||
| split_num = 1 | split_num = 1 | ||||
| net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num)) | |||||
| net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)) | |||||
| net.set_auto_parallel() | net.set_auto_parallel() | ||||
| x = Tensor(np.ones([64, 32]), dtype=ms.float32) | x = Tensor(np.ones([64, 32]), dtype=ms.float32) | ||||
| @@ -64,14 +63,13 @@ def test_embeddinglookup_reducescatter_false(): | |||||
| def test_embeddinglookup_reducescatter_true(): | def test_embeddinglookup_reducescatter_true(): | ||||
| shape = [8, 8] | |||||
| axis = 0 | |||||
| shape = [64, 8] | |||||
| offset = 8 | offset = 8 | ||||
| reduce_scatter_flag = True | reduce_scatter_flag = True | ||||
| split_num = 8 | split_num = 8 | ||||
| net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num)) | |||||
| net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)) | |||||
| net.set_auto_parallel() | net.set_auto_parallel() | ||||
| x = Tensor(np.ones([64, 32]), dtype=ms.float32) | x = Tensor(np.ones([64, 32]), dtype=ms.float32) | ||||
| y = Tensor(np.ones([1, 32, 8]), dtype=ms.float32) | |||||
| y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32) | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||