Merge pull request !5729 from lichen/add_batchnormex_optags/v1.0.0
| @@ -32,12 +32,13 @@ namespace parallel { | |||||
| class GatherV2PInfo : public OperatorInfo { | class GatherV2PInfo : public OperatorInfo { | ||||
| public: | public: | ||||
| GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | |||||
| const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2) | |||||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()), | : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()), | ||||
| axis_(0), | axis_(0), | ||||
| bias_(0), | bias_(0), | ||||
| index_offset_(0), | index_offset_(0), | ||||
| slice_size_(0) {} | |||||
| slice_size_(0), | |||||
| replace_op_name_(replace_op_name) {} | |||||
| ~GatherV2PInfo() override = default; | ~GatherV2PInfo() override = default; | ||||
| Status Init(const StrategyPtr &strategy) override; | Status Init(const StrategyPtr &strategy) override; | ||||
| Status InitForCostModel(const StrategyPtr &strategy) override; | Status InitForCostModel(const StrategyPtr &strategy) override; | ||||
| @@ -69,10 +70,10 @@ class GatherV2PInfo : public OperatorInfo { | |||||
| int32_t axis_; | int32_t axis_; | ||||
| std::string target_ = DEVICE; | std::string target_ = DEVICE; | ||||
| std::string replace_op_name_ = GATHERV2; | |||||
| int64_t bias_; | int64_t bias_; | ||||
| int64_t index_offset_; | int64_t index_offset_; | ||||
| int64_t slice_size_; | int64_t slice_size_; | ||||
| std::string replace_op_name_ = GATHERV2; | |||||
| Shape out_dev_matrix_shape_; | Shape out_dev_matrix_shape_; | ||||
| Group group_; | Group group_; | ||||
| bool manual_split_ = false; | bool manual_split_ = false; | ||||
| @@ -83,12 +84,9 @@ class GatherV2PInfo : public OperatorInfo { | |||||
| class SparseGatherV2Info : public GatherV2PInfo { | class SparseGatherV2Info : public GatherV2PInfo { | ||||
| public: | public: | ||||
| SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | |||||
| : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {} | |||||
| const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2) | |||||
| : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {} | |||||
| ~SparseGatherV2Info() override = default; | ~SparseGatherV2Info() override = default; | ||||
| private: | |||||
| std::string replace_op_name_ = SPARSE_GATHERV2; | |||||
| }; | }; | ||||
| class EmbeddingLookupInfo : public GatherV2PInfo { | class EmbeddingLookupInfo : public GatherV2PInfo { | ||||
| @@ -197,6 +197,7 @@ constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue"; | |||||
| constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue"; | constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue"; | ||||
| constexpr char CONV2D[] = "Conv2D"; | constexpr char CONV2D[] = "Conv2D"; | ||||
| constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm"; | constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm"; | ||||
| constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx"; | |||||
| constexpr char BATCH_NORM[] = "BatchNorm"; | constexpr char BATCH_NORM[] = "BatchNorm"; | ||||
| constexpr char LAYER_NORM[] = "LayerNorm"; | constexpr char LAYER_NORM[] = "LayerNorm"; | ||||
| constexpr char POOLING[] = "Pooling"; | constexpr char POOLING[] = "Pooling"; | ||||
| @@ -263,7 +263,7 @@ bool IsSplittableOperator(const std::string &op_name) { | |||||
| LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, | LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, | ||||
| STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT, | STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT, | ||||
| SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, | SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, | ||||
| EMBEDDING_LOOKUP}; | |||||
| EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX}; | |||||
| // clang-format on | // clang-format on | ||||
| auto iter = splittable_op.find(op_name); | auto iter = splittable_op.find(op_name); | ||||
| @@ -570,8 +570,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st | |||||
| MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2"; | MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2"; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)}; | std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)}; | ||||
| auto prim = GetValueNode<PrimitivePtr>(node->input(0)); | |||||
| if (prim->name() == EMBEDDING_LOOKUP) { | |||||
| if (replace_op.first == EMBEDDING_LOOKUP) { | |||||
| replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)}; | replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)}; | ||||
| } | } | ||||
| if (!params.empty()) { | if (!params.empty()) { | ||||
| @@ -40,7 +40,7 @@ CommManager &CommManager::GetInstance() noexcept { | |||||
| #define HCCL_RUN_CHECK(op_name, group, op) \ | #define HCCL_RUN_CHECK(op_name, group, op) \ | ||||
| do { \ | do { \ | ||||
| auto hccl_result = (op); \ | auto hccl_result = (op); \ | ||||
| if (hccl_result != tagHcclResult::HCCL_SUCCESS) { \ | |||||
| if (hccl_result != 0) { \ | |||||
| MS_LOG(ERROR) << op_name << " failed: #" << group << "#"; \ | MS_LOG(ERROR) << op_name << " failed: #" << group << "#"; \ | ||||
| return false; \ | return false; \ | ||||
| } \ | } \ | ||||
| @@ -0,0 +1,76 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import numpy as np | |||||
| import mindspore as ms | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.common.api import _executor | |||||
| from mindspore.common.parameter import Parameter | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||||
| grad_all = C.GradOperation(get_all=True) | |||||
| class NetWithLoss(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(NetWithLoss, self).__init__() | |||||
| self.loss = VirtualLoss() | |||||
| self.network = network | |||||
| def construct(self, x, y, b): | |||||
| predict = self.network(x, y, b) | |||||
| return self.loss(predict) | |||||
| class GradWrap(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(GradWrap, self).__init__() | |||||
| self.network = network | |||||
| def construct(self, x, y, b): | |||||
| return grad_all(self.network)(x, y, b) | |||||
| # model_parallel test | |||||
| def test_two_matmul_batchnorm_ex(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, strategy1, strategy2): | |||||
| super().__init__() | |||||
| self.matmul1 = P.MatMul().set_strategy(strategy1) | |||||
| self.norm = P.FusedBatchNormEx() | |||||
| self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma") | |||||
| self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta") | |||||
| self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean") | |||||
| self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var") | |||||
| self.matmul2 = P.MatMul().set_strategy(strategy2) | |||||
| def construct(self, x, y, b): | |||||
| out = self.matmul1(x, y) | |||||
| out = self.norm(out, self.gamma, self.beta, self.mean, self.var)[0] | |||||
| out = self.matmul2(out, b) | |||||
| return out | |||||
| context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8) | |||||
| strategy1 = ((4, 2), (2, 1)) | |||||
| strategy2 = ((1, 8), (8, 1)) | |||||
| net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) | |||||
| net.set_auto_parallel() | |||||
| x = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||||
| y = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||||
| b = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||||
| _executor.compile(net, x, y, b) | |||||
| @@ -13,7 +13,6 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| @@ -158,18 +157,6 @@ def test_gatherv2_semi_auto7(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| def test_gatherv2_semi_auto8(): | |||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||||
| strategy1 = ((8,), (1, 1)) | |||||
| strategy2 = ((4, 2), (4, 2)) | |||||
| net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2))) | |||||
| net.set_auto_parallel() | |||||
| x = Tensor(np.ones([64]), dtype=ms.float32) | |||||
| y = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||||
| _executor.compile(net, x, y) | |||||
| def test_gatherv2_auto0(): | def test_gatherv2_auto0(): | ||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel") | context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel") | ||||
| net = GradWrap(NetWithLoss(Net(0))) | net = GradWrap(NetWithLoss(Net(0))) | ||||
| @@ -188,7 +175,6 @@ def test_gatherv2_auto1(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||||
| def test_gatherv2_cpu0(): | def test_gatherv2_cpu0(): | ||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | ||||
| strategy1 = ((8, 1), (1, 1)) | strategy1 = ((8, 1), (1, 1)) | ||||
| @@ -201,7 +187,6 @@ def test_gatherv2_cpu0(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||||
| def test_gatherv2_cpu1(): | def test_gatherv2_cpu1(): | ||||
| context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") | ||||
| strategy1 = ((16, 1), (1, 1)) | strategy1 = ((16, 1), (1, 1)) | ||||
| @@ -214,7 +199,6 @@ def test_gatherv2_cpu1(): | |||||
| _executor.compile(net, x, y) | _executor.compile(net, x, y) | ||||
| @pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") | |||||
| def test_gatherv2_cpu2(): | def test_gatherv2_cpu2(): | ||||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | ||||
| strategy1 = ((1, 8), (1, 1)) | strategy1 = ((1, 8), (1, 1)) | ||||