| @@ -179,6 +179,7 @@ REGISTER(SquareInfo); | |||
| REGISTER(UniformCandidateSamplerInfo); | |||
| REGISTER(UnsortedSegmentSumInfo); | |||
| REGISTER(UnsortedSegmentMinInfo); | |||
| REGISTER(UnsortedSegmentMaxInfo); | |||
| REGISTER(GatherV2PInfo); | |||
| REGISTER(EmbeddingLookupInfo); | |||
| REGISTER(TileInfo); | |||
| @@ -305,6 +305,7 @@ constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD"; | |||
| constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; | |||
| constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum"; | |||
| constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin"; | |||
| constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax"; | |||
| constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; | |||
| constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; | |||
| constexpr char ADD[] = "Add"; | |||
| @@ -332,5 +332,41 @@ Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| return SUCCESS; | |||
| } | |||
| // The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo | |||
| // Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op | |||
| ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) { | |||
| auto input_id_strategy = strategy_->GetInputDim().at(1); | |||
| // 1. the two input shapes are same, and the strategy is not all ones | |||
| if (std::any_of(input_id_strategy.begin(), input_id_strategy.end(), [](const int64_t &shard) { return shard > 1; })) { | |||
| if (ComputeReplaceGraph(cnode) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; | |||
| } | |||
| } | |||
| return replace_graph_; | |||
| } | |||
| Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| GenerateGraph gen_g = GenerateGraph(); | |||
| if (gen_g.Init(cnode) != SUCCESS) { | |||
| MS_LOG(ERROR) << "GenerateGraph Init failed"; | |||
| return FAILED; | |||
| } | |||
| // Get the attributes of the UnsortedSegmentMin | |||
| auto num_segments = GetValue<int64_t>(input_value_.at(2)); | |||
| // Step1: Output branch | |||
| auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(), | |||
| gen_g.virtual_input_node(), CreatInt64Imm(num_segments)}); | |||
| auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)}); | |||
| auto all_gather_output = gen_g.PushBack({gen_g.NewOpInst(ALL_GATHER), expandim_output}); | |||
| auto final_output = gen_g.PushBack({gen_g.NewOpInst(REDUCE_MAX), all_gather_output, CreatInt64Imm(0)}); | |||
| std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(segment_max, 1), | |||
| std::make_pair(segment_max, 2)}; | |||
| replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>( | |||
| std::make_pair(input_nodes, final_output)); | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -79,6 +79,20 @@ class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo { | |||
| Status ComputeReplaceGraph(const CNodePtr &cnode); | |||
| }; | |||
| class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo { | |||
| public: | |||
| UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {} | |||
| ~UnsortedSegmentMaxInfo() override = default; | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||
| Status InferForwardCommunication() override { return SUCCESS; } | |||
| protected: | |||
| Status ComputeReplaceGraph(const CNodePtr &cnode); | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_ | |||
| @@ -317,7 +317,8 @@ bool IsSplittableOperator(const std::string &op_name) { | |||
| EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE, | |||
| BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, | |||
| SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM, | |||
| UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE}; | |||
| UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, | |||
| UNSORTED_SEGMENT_MAX}; | |||
| // clang-format on | |||
| auto iter = splittable_op.find(op_name); | |||
| @@ -0,0 +1,162 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations.comm_ops import _VirtualDataset | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| grad_all = C.GradOperation(get_all=True) | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1, strategy2, num_segments): | |||
| super(Net, self).__init__() | |||
| self.virtual_dataset = _VirtualDataset() | |||
| self.merge_op = P.UnsortedSegmentMax().shard((strategy1, strategy2)) | |||
| self.num_segments = num_segments | |||
| def construct(self, vectors, segment_ids): | |||
| predict = self.merge_op(vectors, segment_ids, self.num_segments) | |||
| return predict | |||
| class GradWrap(nn.Cell): | |||
| def __init__(self, network): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, x, y): | |||
| return grad_all(self.network)(x, y) | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x, y): | |||
| predict = self.network(x, y) | |||
| return self.loss(predict) | |||
| def compile_graph(x, y, segments, strategy1, strategy2, auto=False): | |||
| net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments))) | |||
| net.set_auto_parallel() | |||
| net.set_train() | |||
| if auto: | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| else: | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| _executor.compile(net, x, y) | |||
| def test_UnsortedSegmentMax_model_parallel_slice_1d(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| x = Tensor(np.ones(8), ms.float32) | |||
| y = Tensor(np.ones(8), ms.int32) | |||
| num_segments = 16 | |||
| strategy1 = (8,) | |||
| strategy2 = (8,) | |||
| compile_graph(x, y, num_segments, strategy1, strategy2) | |||
| def test_UnsortedSegmentMax_model_parallel_no_slice_1d(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| x = Tensor(np.ones(8), ms.float32) | |||
| y = Tensor(np.ones(8), ms.int32) | |||
| num_segments = 16 | |||
| strategy1 = (1,) | |||
| strategy2 = (1,) | |||
| compile_graph(x, y, num_segments, strategy1, strategy2) | |||
| def test_UnsortedSegmentMax_model_parallel_index_slice_2d(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=0) | |||
| x = Tensor(np.ones((4, 8)), ms.float32) | |||
| y = Tensor(np.arange(4), ms.int32) | |||
| num_segments = 4 | |||
| strategy1 = (4, 1) | |||
| strategy2 = (4,) | |||
| compile_graph(x, y, num_segments, strategy1, strategy2) | |||
| def test_UnsortedSegmentMax_model_parallel_vector_slice_2d(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=0) | |||
| x = Tensor(np.ones((4, 8)), ms.float32) | |||
| y = Tensor(np.ones(4), ms.int32) | |||
| num_segments = 4 | |||
| strategy1 = (1, 4) | |||
| strategy2 = (1,) | |||
| compile_graph(x, y, num_segments, strategy1, strategy2) | |||
| def test_UnsortedSegmentMax_model_parallel_vector_slice_3d(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=0) | |||
| x = Tensor(np.ones((4, 8, 8)), ms.float32) | |||
| y = Tensor(np.ones(4), ms.int32) | |||
| num_segments = 4 | |||
| strategy1 = (1, 2, 2) | |||
| strategy2 = (1,) | |||
| compile_graph(x, y, num_segments, strategy1, strategy2) | |||
| def test_UnsortedSegmentMax_model_parallel_index_vector_slice_2d(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=0) | |||
| x = Tensor(np.ones((4, 8)), ms.float32) | |||
| y = Tensor(np.ones(4), ms.int32) | |||
| num_segments = 4 | |||
| strategy1 = (2, 2) | |||
| strategy2 = (2,) | |||
| compile_graph(x, y, num_segments, strategy1, strategy2) | |||
| def test_UnsortedSegmentMax_model_parallel_index_vector_slice_3d(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=0) | |||
| x = Tensor(np.ones((4, 4, 8)), ms.float32) | |||
| y = Tensor(np.ones((4)), ms.int32) | |||
| num_segments = 16 | |||
| strategy1 = (2, 1, 2) | |||
| strategy2 = (2,) | |||
| compile_graph(x, y, num_segments, strategy1, strategy2) | |||
| def test_UnsortedSegmentMax_model_parallel_float16(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=0) | |||
| x = Tensor(np.ones((4, 4, 8)), ms.float16) | |||
| y = Tensor(np.ones((4)), ms.int32) | |||
| num_segments = 16 | |||
| strategy1 = (2, 1, 2) | |||
| strategy2 = (2,) | |||
| compile_graph(x, y, num_segments, strategy1, strategy2) | |||
| def test_UnsortedSegmentMax_model_parallel_int32(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=0) | |||
| x = Tensor(np.ones((4, 4, 8)), ms.int32) | |||
| y = Tensor(np.ones((4)), ms.int32) | |||
| num_segments = 16 | |||
| strategy1 = (2, 1, 2) | |||
| strategy2 = (2,) | |||
| compile_graph(x, y, num_segments, strategy1, strategy2) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -11,6 +11,7 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -11,6 +11,7 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||