| @@ -132,6 +132,7 @@ REGISTER(SqueezeInfo); | |||||
| REGISTER(SigmoidCrossEntropyWithLogitsInfo); | REGISTER(SigmoidCrossEntropyWithLogitsInfo); | ||||
| REGISTER(SquareInfo); | REGISTER(SquareInfo); | ||||
| REGISTER(GatherV2PInfo); | REGISTER(GatherV2PInfo); | ||||
| REGISTER(EmbeddingLookupInfo); | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -28,24 +28,25 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status GatherV2PInfo::GetAttrs() { | Status GatherV2PInfo::GetAttrs() { | ||||
| // get axis, the third input is the axis, is a ValueNode | |||||
| if (input_value_.at(2) == nullptr) { | |||||
| MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; | |||||
| return FAILED; | |||||
| } | |||||
| auto axis = GetValue<int>(input_value_.at(2)); | |||||
| // if axis is negative then convert it to positive | |||||
| auto params_shape = inputs_shape_.at(0); | |||||
| if (params_shape.size() == 0) { | |||||
| MS_LOG(ERROR) << name_ << ": params can not be a scalar!"; | |||||
| return FAILED; | |||||
| } | |||||
| if (axis < 0) { | |||||
| axis += SizeToInt(inputs_shape_[0].size()); | |||||
| // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis. | |||||
| if (target_ != CPU) { | |||||
| if (input_value_.at(2) == nullptr) { | |||||
| MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; | |||||
| return FAILED; | |||||
| } | |||||
| auto axis = GetValue<int>(input_value_.at(2)); | |||||
| // if axis is negative then convert it to positive | |||||
| auto params_shape = inputs_shape_.at(0); | |||||
| if (params_shape.size() == 0) { | |||||
| MS_LOG(ERROR) << name_ << ": params can not be a scalar!"; | |||||
| return FAILED; | |||||
| } | |||||
| if (axis < 0) { | |||||
| axis += SizeToInt(inputs_shape_[0].size()); | |||||
| } | |||||
| axis_ = axis; | |||||
| } | } | ||||
| axis_ = axis; | |||||
| // get target | |||||
| auto target_iter = attrs_.find(TARGET); | auto target_iter = attrs_.find(TARGET); | ||||
| if (target_iter != attrs_.end()) { | if (target_iter != attrs_.end()) { | ||||
| MS_EXCEPTION_IF_NULL(target_iter->second); | MS_EXCEPTION_IF_NULL(target_iter->second); | ||||
| @@ -53,16 +54,8 @@ Status GatherV2PInfo::GetAttrs() { | |||||
| target_ = target_iter->second->cast<StringImmPtr>()->value(); | target_ = target_iter->second->cast<StringImmPtr>()->value(); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << name_ << " : The value of target is not a string."; | MS_LOG(ERROR) << name_ << " : The value of target is not a string."; | ||||
| return FAILED; | |||||
| } | } | ||||
| } | } | ||||
| // target=CPU, axis must be 0 | |||||
| if (target_ == "CPU" && axis_ != 0) { | |||||
| MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_; | |||||
| return FAILED; | |||||
| } | |||||
| auto manual_split_iter = attrs_.find("manual_split"); | auto manual_split_iter = attrs_.find("manual_split"); | ||||
| if (manual_split_iter != attrs_.end()) { | if (manual_split_iter != attrs_.end()) { | ||||
| param_split_shapes_.clear(); | param_split_shapes_.clear(); | ||||
| @@ -459,38 +452,13 @@ Status GatherV2PInfo::InferForwardCommunication() { | |||||
| MS_LOG(ERROR) << name_ << ": Infer Group failed."; | MS_LOG(ERROR) << name_ << ": Infer Group failed."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| auto group_size = group_.GetDevNum(); | |||||
| Attr attr_group; | Attr attr_group; | ||||
| if (host_reduce_scatter_) { | |||||
| // group size <= 8 | |||||
| std::vector<int32_t> rank_list; | |||||
| if (group_size <= 8) { | |||||
| reduce_scatter_flag_ = false; | |||||
| operator_name = HOST_REDUCE_SCATTER; | |||||
| rank_list = GetRankFromGroup(group_); | |||||
| attr_group = std::make_pair(GROUP, MakeValue(rank_list)); | |||||
| } else { | |||||
| // group size > 8, don't support host reduce_scatter | |||||
| reduce_scatter_flag_ = true; | |||||
| split_num_ = SizeToInt(group_size / 8); | |||||
| CheckGlobalDeviceManager(); | |||||
| operator_name = REDUCE_SCATTER; | |||||
| int32_t rank = g_device_manager->global_rank(); | |||||
| size_t repeat = group_size / 8; | |||||
| for (size_t i = 0; i < repeat; ++i) { | |||||
| rank_list.push_back(rank + SizeToInt(i * 8)); | |||||
| } | |||||
| Group g = g_device_manager->CreateGroup(rank_list); | |||||
| attr_group = std::make_pair(GROUP, MakeValue(g.name())); | |||||
| } | |||||
| } else { | |||||
| operator_name = REDUCE_SCATTER; | |||||
| if (InferGroup() != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Infer Group failed."; | |||||
| return FAILED; | |||||
| } | |||||
| attr_group = std::make_pair(GROUP, MakeValue(group_.name())); | |||||
| operator_name = REDUCE_SCATTER; | |||||
| if (InferGroup() != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Infer Group failed."; | |||||
| return FAILED; | |||||
| } | } | ||||
| attr_group = std::make_pair(GROUP, MakeValue(group_.name())); | |||||
| Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); | Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); | ||||
| OperatorAttrs attrs = {attr_op, attr_group}; | OperatorAttrs attrs = {attr_op, attr_group}; | ||||
| OperatorParams params; | OperatorParams params; | ||||
| @@ -582,10 +550,7 @@ Status GatherV2PInfo::ComputeReplaceOp() { | |||||
| OperatorName op_name = EMBEDDING_LOOKUP; | OperatorName op_name = EMBEDDING_LOOKUP; | ||||
| OperatorAttrs attrs; | OperatorAttrs attrs; | ||||
| Attr param_offset = std::make_pair("offset", MakeValue(bias_)); | Attr param_offset = std::make_pair("offset", MakeValue(bias_)); | ||||
| Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_)); | |||||
| Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_)); | |||||
| OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4), | |||||
| std::make_pair(param_split_num, 5)}; | |||||
| OperatorParams params = {std::make_pair(param_offset, 3)}; | |||||
| OperatorArgs args = std::make_pair(attrs, params); | OperatorArgs args = std::make_pair(attrs, params); | ||||
| Operator op = std::make_pair(op_name, args); | Operator op = std::make_pair(op_name, args); | ||||
| replace_op_.push_back(op); | replace_op_.push_back(op); | ||||
| @@ -65,16 +65,13 @@ class GatherV2PInfo : public OperatorInfo { | |||||
| Status InferGroup(); | Status InferGroup(); | ||||
| int32_t axis_; | int32_t axis_; | ||||
| std::string target_; | |||||
| std::string target_ = DEVICE; | |||||
| std::string replace_op_name_ = GATHERV2; | std::string replace_op_name_ = GATHERV2; | ||||
| int32_t bias_; | int32_t bias_; | ||||
| int32_t index_offset_; | int32_t index_offset_; | ||||
| int32_t slice_size_; | int32_t slice_size_; | ||||
| Shape out_dev_matrix_shape_; | Shape out_dev_matrix_shape_; | ||||
| Group group_; | Group group_; | ||||
| bool reduce_scatter_flag_ = false; | |||||
| int32_t split_num_ = 1; | |||||
| bool host_reduce_scatter_ = false; | |||||
| bool manual_split_ = false; | bool manual_split_ = false; | ||||
| std::vector<int32_t> param_split_shapes_; | std::vector<int32_t> param_split_shapes_; | ||||
| std::vector<int32_t> index_offsets_; | std::vector<int32_t> index_offsets_; | ||||
| @@ -90,6 +87,14 @@ class SparseGatherV2Info : public GatherV2PInfo { | |||||
| private: | private: | ||||
| std::string replace_op_name_ = SPARSE_GATHERV2; | std::string replace_op_name_ = SPARSE_GATHERV2; | ||||
| }; | }; | ||||
| class EmbeddingLookupInfo : public GatherV2PInfo { | |||||
| public: | |||||
| EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||||
| const PrimitiveAttrs &attrs) | |||||
| : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {} | |||||
| ~EmbeddingLookupInfo() override = default; | |||||
| }; | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ | #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ | ||||
| @@ -132,6 +132,7 @@ constexpr char REDISTRIBUTION_OP[] = "redistribution_op"; | |||||
| constexpr char DARA_PARALLEL[] = "data_parallel"; | constexpr char DARA_PARALLEL[] = "data_parallel"; | ||||
| constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter"; | constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter"; | ||||
| constexpr char OPTIMIZER_SUB_STRING[] = "optimizer"; | constexpr char OPTIMIZER_SUB_STRING[] = "optimizer"; | ||||
| constexpr char DEVICE[] = "Device"; | |||||
| // Operator | // Operator | ||||
| constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; | constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; | ||||
| @@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st | |||||
| } | } | ||||
| std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)}; | std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)}; | ||||
| auto prim = GetValueNode<PrimitivePtr>(node->input(0)); | auto prim = GetValueNode<PrimitivePtr>(node->input(0)); | ||||
| if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) { | |||||
| if (prim->name() == EMBEDDING_LOOKUP) { | |||||
| replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)}; | replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)}; | ||||
| } | } | ||||
| if (!params.empty()) { | if (!params.empty()) { | ||||
| @@ -105,3 +105,49 @@ class Embedding(Cell): | |||||
| self.embedding_table, | self.embedding_table, | ||||
| self.dtype) | self.dtype) | ||||
| return s | return s | ||||
class EmbeddingLookup(Cell):
    r"""
    Returns a slice of input tensor based on the specified indices.

    Note:
        When 'target' is set to 'CPU', this module will use
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
        specified 'offset = 0' to lookup table.
        When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
        specified 'axis = 0' to lookup table.

    Args:
        target (str): Specify the target where the op is executed. Must be 'CPU' or 'DEVICE'.
            Default: 'CPU'.

    Inputs:
        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          The Tensor slice, instead of the entire Tensor.
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
          and the exceeding part will be filled with 0 in the output.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

    Raises:
        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.

    Examples:
        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
        >>> out = nn.EmbeddingLookup()(input_params, input_indices)
        [[[10, 11], [8 ,9]], [[14, 15], [12, 13]]]
    """

    def __init__(self, target='CPU'):
        super(EmbeddingLookup, self).__init__()
        self.target = target
        if target not in ('CPU', 'DEVICE'):
            raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
                             + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
        # Both primitives are created up front; construct() picks one based on self.target.
        self.gatherv2 = P.GatherV2()
        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')

    def construct(self, params, indices):
        # The trailing 0 is the fixed offset for EmbeddingLookup and the fixed axis for GatherV2.
        if self.target == "CPU":
            out = self.embeddinglookup(params, indices, 0)
        else:
            out = self.gatherv2(params, indices, 0)
        return out
| @@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell): | |||||
| self.deep_layer_act, | self.deep_layer_act, | ||||
| use_activation=False, convert_dtype=True, drop_out=config.dropout_flag) | use_activation=False, convert_dtype=True, drop_out=config.dropout_flag) | ||||
| self.gather_v2 = P.GatherV2() | |||||
| self.embeddinglookup = nn.EmbeddingLookup() | |||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | self.reduce_sum = P.ReduceSum(keep_dims=False) | ||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| @@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell): | |||||
| """ | """ | ||||
| mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | ||||
| # Wide layer | # Wide layer | ||||
| wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0) | |||||
| wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr, 0) | |||||
| wx = self.mul(wide_id_weight, mask) | wx = self.mul(wide_id_weight, mask) | ||||
| wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) | wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) | ||||
| # Deep layer | # Deep layer | ||||
| deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0) | |||||
| deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr, 0) | |||||
| vx = self.mul(deep_id_embs, mask) | vx = self.mul(deep_id_embs, mask) | ||||
| deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim)) | deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim)) | ||||
| deep_in = self.dense_layer_1(deep_in) | deep_in = self.dense_layer_1(deep_in) | ||||
| @@ -41,12 +41,12 @@ class NetWithLoss(nn.Cell): | |||||
| return self.loss(predict) | return self.loss(predict) | ||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self, shape, offset): | |||||
| def __init__(self, shape, offset, strategy1=None, strategy2=None, target="Device"): | |||||
| super().__init__() | super().__init__() | ||||
| self.index = Tensor(np.ones(shape), dtype=ms.int32) | self.index = Tensor(np.ones(shape), dtype=ms.int32) | ||||
| self.offset = offset | self.offset = offset | ||||
| self.elu = P.EmbeddingLookup() | |||||
| self.mm = P.BatchMatMul() | |||||
| self.elu = P.EmbeddingLookup().set_strategy(strategy1).add_prim_attr("primitive_target", target) | |||||
| self.mm = P.BatchMatMul().set_strategy(strategy2) | |||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| out = self.elu(x, self.index, self.offset) | out = self.elu(x, self.index, self.offset) | ||||
| @@ -97,3 +97,31 @@ def test_embeddinglookup_reducescatter_true_grad(): | |||||
| x = Tensor(np.ones([64, 32]), dtype=ms.float32) | x = Tensor(np.ones([64, 32]), dtype=ms.float32) | ||||
| y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32) | y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32) | ||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
def test_embeddinglookup_semi_auto1():
    """Compile the CPU-target EmbeddingLookup net in semi-auto-parallel mode with strategy ((8, 1), (1, 1))."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    lookup_strategy = ((8, 1), (1, 1))
    matmul_strategy = ((4, 1, 2), (4, 2, 1))
    # Index shape [64, 32] and offset 0 are forwarded to Net together with both strategies.
    inner_net = Net([64, 32], 0, lookup_strategy, matmul_strategy, "CPU")
    net = GradWrap(NetWithLoss(inner_net))
    net.set_auto_parallel()
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _executor.compile(net, x, y)
def test_embeddinglookup_semi_auto2():
    """Compile the CPU-target EmbeddingLookup net in semi-auto-parallel mode with strategy ((1, 8), (1, 1))."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    lookup_strategy = ((1, 8), (1, 1))
    matmul_strategy = ((4, 1, 2), (4, 2, 1))
    # Index shape [64, 32] and offset 0 are forwarded to Net together with both strategies.
    inner_net = Net([64, 32], 0, lookup_strategy, matmul_strategy, "CPU")
    net = GradWrap(NetWithLoss(inner_net))
    net.set_auto_parallel()
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _executor.compile(net, x, y)
| @@ -13,8 +13,6 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| @@ -183,42 +181,3 @@ def test_gatherv2_auto1(): | |||||
| x = Tensor(np.ones([64, 32]), dtype=ms.float32) | x = Tensor(np.ones([64, 32]), dtype=ms.float32) | ||||
| y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) | y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) | ||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||||
| def test_gatherv2_cpu0(): | |||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||||
| strategy1 = ((8, 1), (1, 1)) | |||||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||||
| net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU")) | |||||
| net.set_auto_parallel() | |||||
| x = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||||
| y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) | |||||
| _executor.compile(net, x, y) | |||||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||||
| def test_gatherv2_cpu1(): | |||||
| context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | |||||
| strategy1 = ((16, 1), (1, 1)) | |||||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||||
| net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU")) | |||||
| net.set_auto_parallel() | |||||
| x = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||||
| y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) | |||||
| _executor.compile(net, x, y) | |||||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||||
| def test_gatherv2_cpu2(): | |||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||||
| strategy1 = ((1, 8), (1, 1)) | |||||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||||
| net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU")) | |||||
| net.set_auto_parallel() | |||||
| x = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||||
| y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) | |||||
| _executor.compile(net, x, y) | |||||