| @@ -179,6 +179,7 @@ REGISTER(SquareInfo); | |||||
| REGISTER(UniformCandidateSamplerInfo); | REGISTER(UniformCandidateSamplerInfo); | ||||
| REGISTER(UnsortedSegmentSumInfo); | REGISTER(UnsortedSegmentSumInfo); | ||||
| REGISTER(UnsortedSegmentMinInfo); | REGISTER(UnsortedSegmentMinInfo); | ||||
| REGISTER(UnsortedSegmentMaxInfo); | |||||
| REGISTER(GatherV2PInfo); | REGISTER(GatherV2PInfo); | ||||
| REGISTER(EmbeddingLookupInfo); | REGISTER(EmbeddingLookupInfo); | ||||
| REGISTER(TileInfo); | REGISTER(TileInfo); | ||||
| @@ -305,6 +305,7 @@ constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD"; | |||||
| constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; | constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; | ||||
| constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum"; | constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum"; | ||||
| constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin"; | constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin"; | ||||
| constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax"; | |||||
| constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; | constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; | ||||
| constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; | constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; | ||||
| constexpr char ADD[] = "Add"; | constexpr char ADD[] = "Add"; | ||||
| @@ -332,5 +332,41 @@ Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo | |||||
| // Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op | |||||
| ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) { | |||||
| auto input_id_strategy = strategy_->GetInputDim().at(1); | |||||
| // 1. the two input shapes are same, and the strategy is not all ones | |||||
| if (std::any_of(input_id_strategy.begin(), input_id_strategy.end(), [](const int64_t &shard) { return shard > 1; })) { | |||||
| if (ComputeReplaceGraph(cnode) != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; | |||||
| } | |||||
| } | |||||
| return replace_graph_; | |||||
| } | |||||
| Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||||
| GenerateGraph gen_g = GenerateGraph(); | |||||
| if (gen_g.Init(cnode) != SUCCESS) { | |||||
| MS_LOG(ERROR) << "GenerateGraph Init failed"; | |||||
| return FAILED; | |||||
| } | |||||
| // Get the attributes of the UnsortedSegmentMin | |||||
| auto num_segments = GetValue<int64_t>(input_value_.at(2)); | |||||
| // Step1: Output branch | |||||
| auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(), | |||||
| gen_g.virtual_input_node(), CreatInt64Imm(num_segments)}); | |||||
| auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)}); | |||||
| auto all_gather_output = gen_g.PushBack({gen_g.NewOpInst(ALL_GATHER), expandim_output}); | |||||
| auto final_output = gen_g.PushBack({gen_g.NewOpInst(REDUCE_MAX), all_gather_output, CreatInt64Imm(0)}); | |||||
| std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(segment_max, 1), | |||||
| std::make_pair(segment_max, 2)}; | |||||
| replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>( | |||||
| std::make_pair(input_nodes, final_output)); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -79,6 +79,20 @@ class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo { | |||||
| Status ComputeReplaceGraph(const CNodePtr &cnode); | Status ComputeReplaceGraph(const CNodePtr &cnode); | ||||
| }; | }; | ||||
| class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo { | |||||
| public: | |||||
| UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||||
| const PrimitiveAttrs &attrs) | |||||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {} | |||||
| ~UnsortedSegmentMaxInfo() override = default; | |||||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||||
| Status InferForwardCommunication() override { return SUCCESS; } | |||||
| protected: | |||||
| Status ComputeReplaceGraph(const CNodePtr &cnode); | |||||
| }; | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_ | #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_ | ||||
| @@ -317,7 +317,8 @@ bool IsSplittableOperator(const std::string &op_name) { | |||||
| EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE, | EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE, | ||||
| BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, | BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, | ||||
| SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM, | SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM, | ||||
| UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE}; | |||||
| UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, | |||||
| UNSORTED_SEGMENT_MAX}; | |||||
| // clang-format on | // clang-format on | ||||
| auto iter = splittable_op.find(op_name); | auto iter = splittable_op.find(op_name); | ||||
| @@ -0,0 +1,162 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore as ms | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.common.api import _executor | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations.comm_ops import _VirtualDataset | |||||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| grad_all = C.GradOperation(get_all=True) | |||||
class Net(nn.Cell):
    """Cell that applies a sharded UnsortedSegmentMax to (vectors, segment_ids)."""

    def __init__(self, strategy1, strategy2, num_segments):
        super().__init__()
        self.virtual_dataset = _VirtualDataset()
        # strategy1 shards the first operator input, strategy2 the second.
        self.merge_op = P.UnsortedSegmentMax().shard((strategy1, strategy2))
        self.num_segments = num_segments

    def construct(self, vectors, segment_ids):
        return self.merge_op(vectors, segment_ids, self.num_segments)
class GradWrap(nn.Cell):
    """Cell that applies the module-level grad_all operation to the wrapped network."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y):
        grad_fn = grad_all(self.network)
        return grad_fn(x, y)
class NetWithLoss(nn.Cell):
    """Cell that feeds the wrapped network's output into a VirtualLoss."""

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        return self.loss(self.network(x, y))
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
    """Wrap Net in NetWithLoss and GradWrap, then compile it for inputs x and y.

    The parallel mode is "auto_parallel" when `auto` is truthy and
    "semi_auto_parallel" otherwise.
    """
    forward = Net(strategy1, strategy2, segments)
    net = GradWrap(NetWithLoss(forward))
    net.set_auto_parallel()
    net.set_train()
    parallel_mode = "auto_parallel" if auto else "semi_auto_parallel"
    context.set_auto_parallel_context(parallel_mode=parallel_mode)
    _executor.compile(net, x, y)
def test_UnsortedSegmentMax_model_parallel_slice_1d():
    """Compile with 8 devices and both 1-D inputs using strategy (8,)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    vectors = Tensor(np.ones(8), ms.float32)
    segment_ids = Tensor(np.ones(8), ms.int32)
    compile_graph(vectors, segment_ids, 16, strategy1=(8,), strategy2=(8,))
def test_UnsortedSegmentMax_model_parallel_no_slice_1d():
    """Compile with 8 devices and both 1-D inputs using strategy (1,)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    vectors = Tensor(np.ones(8), ms.float32)
    segment_ids = Tensor(np.ones(8), ms.int32)
    compile_graph(vectors, segment_ids, 16, strategy1=(1,), strategy2=(1,))
def test_UnsortedSegmentMax_model_parallel_index_slice_2d():
    """Compile with 4 devices: (4, 8) data with strategy (4, 1), ids 0..3 with strategy (4,)."""
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8)), ms.float32)
    segment_ids = Tensor(np.arange(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, strategy1=(4, 1), strategy2=(4,))
def test_UnsortedSegmentMax_model_parallel_vector_slice_2d():
    """Compile with 4 devices: (4, 8) data with strategy (1, 4), ids with strategy (1,)."""
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, strategy1=(1, 4), strategy2=(1,))
def test_UnsortedSegmentMax_model_parallel_vector_slice_3d():
    """Compile with 4 devices: (4, 8, 8) data with strategy (1, 2, 2), ids with strategy (1,)."""
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, strategy1=(1, 2, 2), strategy2=(1,))
def test_UnsortedSegmentMax_model_parallel_index_vector_slice_2d():
    """Compile with 4 devices: (4, 8) data with strategy (2, 2), ids with strategy (2,)."""
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, strategy1=(2, 2), strategy2=(2,))
def test_UnsortedSegmentMax_model_parallel_index_vector_slice_3d():
    """Compile with 4 devices: (4, 4, 8) data with strategy (2, 1, 2), ids with strategy (2,)."""
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 4, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 16, strategy1=(2, 1, 2), strategy2=(2,))
def test_UnsortedSegmentMax_model_parallel_float16():
    """Compile with 4 devices and float16 data of shape (4, 4, 8), strategy (2, 1, 2) / (2,)."""
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 4, 8)), ms.float16)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 16, strategy1=(2, 1, 2), strategy2=(2,))
def test_UnsortedSegmentMax_model_parallel_int32():
    """Compile with 4 devices and int32 data of shape (4, 4, 8), strategy (2, 1, 2) / (2,)."""
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 4, 8)), ms.int32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 16, strategy1=(2, 1, 2), strategy2=(2,))
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -11,6 +11,7 @@ | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | |||||
| import numpy as np | import numpy as np | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -11,6 +11,7 @@ | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | |||||
| import numpy as np | import numpy as np | ||||