From: @xiaoda_zh Reviewed-by: @stsuteng Signed-off-by: @stsutengtags/v1.1.0
| @@ -23,68 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM, | |||||
| J, | |||||
| LIST_GETITEM, | |||||
| ARRAY_GETITEM, | |||||
| TUPLE_SETITEM, | |||||
| DEPEND, | |||||
| LIST_SETITEM, | |||||
| ARRAY_SETITEM, | |||||
| DICT_GETITEM, | |||||
| LIST_APPEND, | |||||
| LIST_MAP, | |||||
| LIST_REDUCE, | |||||
| TUPLE_REVERSED, | |||||
| TILE_SHAPE, | |||||
| TUPLE_DIV, | |||||
| TUPLE_TO_ARRAY, | |||||
| MAKE_DICT, | |||||
| MAKE_SLICE, | |||||
| MAKE_RECORD, | |||||
| STRING_EQUAL, | |||||
| VIRTUALLOSS, | |||||
| RETURN, | |||||
| ENV_GETITEM, | |||||
| IDENTITY, | |||||
| PARTIAL, | |||||
| ENVSETITEM, | |||||
| ENVGETITEM, | |||||
| ENVADD, | |||||
| MAKEREFKEY, | |||||
| MAKEREF, | |||||
| GETREFKEY, | |||||
| GETREFVALUE, | |||||
| GETREFORIGIN, | |||||
| DOT, | |||||
| IM2COL, | |||||
| COL2IM, | |||||
| IM2COLV1, | |||||
| STATESETITEM, | |||||
| SCALARSUMMARY, | |||||
| IMAGESUMMARY, | |||||
| TENSORSUMMARY, | |||||
| DEBUG, | |||||
| HISTOGRAMSUMMARY, | |||||
| COL2IMV1, | |||||
| RESOLVE, | |||||
| BROADCASTGRADIENTARGS, | |||||
| INVERTPERMUTATION, | |||||
| CONTROLDEPEND, | |||||
| DROPOUT_GEN_MASK, | |||||
| EMBED, | |||||
| CREATINSTANCE, | |||||
| REF_TO_EMBED, | |||||
| STOP_GRADIENT, | |||||
| SEND}; | |||||
| const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER}; | const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER}; | ||||
| bool IsInBlackList(const PrimitivePtr &prim) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end()); | |||||
| } | |||||
| bool IsInBatchParallelBlackList(const PrimitivePtr &prim) { | bool IsInBatchParallelBlackList(const PrimitivePtr &prim) { | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| return (BATCH_PARALLEL_BLACK_LIST.find(prim->name()) != BATCH_PARALLEL_BLACK_LIST.end()); | return (BATCH_PARALLEL_BLACK_LIST.find(prim->name()) != BATCH_PARALLEL_BLACK_LIST.end()); | ||||
| @@ -21,7 +21,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| bool IsInBlackList(const PrimitivePtr &prim); | |||||
| bool IsInBatchParallelBlackList(const PrimitivePtr &prim); | bool IsInBatchParallelBlackList(const PrimitivePtr &prim); | ||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,6 +32,7 @@ | |||||
| #include "base/core_ops.h" | #include "base/core_ops.h" | ||||
| #include "utils/comm_manager.h" | #include "utils/comm_manager.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "mindspore/core/utils/parallel_node_check.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -99,7 +100,7 @@ bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) { | |||||
| if (IsInWhiteList(cnode)) { | if (IsInWhiteList(cnode)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (IsInBlackList(prim)) { | |||||
| if (IsInParallelBlackList(prim)) { | |||||
| MS_LOG(INFO) << "PipelineSplit don't care node:" << prim->name(); | MS_LOG(INFO) << "PipelineSplit don't care node:" << prim->name(); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -44,6 +44,7 @@ | |||||
| #include "utils/comm_manager.h" | #include "utils/comm_manager.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| #include "mindspore/core/utils/parallel_node_check.h" | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #endif | #endif | ||||
| @@ -439,7 +440,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { | |||||
| if (prim == nullptr) { | if (prim == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (IsInBlackList(prim)) { | |||||
| if (IsInParallelBlackList(prim)) { | |||||
| MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name(); | MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name(); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "utils/profile.h" | #include "utils/profile.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "ir/graph_utils.h" | #include "ir/graph_utils.h" | ||||
| #include "utils/parallel_node_check.h" | |||||
| // namespace to support intermediate representation definition | // namespace to support intermediate representation definition | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -91,7 +92,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||||
| new_node->set_inputs_value(old_node->inputs_value()); | new_node->set_inputs_value(old_node->inputs_value()); | ||||
| ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); | ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); | ||||
| new_node->set_scope(scope); | new_node->set_scope(scope); | ||||
| if (IsPrimitiveCNode(old_node, nullptr) && new_node->scope() == kDefaultScope) { | |||||
| if (IsParallelCareCNode(old_node) && new_node->scope() == kDefaultScope) { | |||||
| new_node->set_fullname_with_scope(old_node->fullname_with_scope()); | new_node->set_fullname_with_scope(old_node->fullname_with_scope()); | ||||
| } | } | ||||
| new_node->set_kernel_info(old_node->kernel_info_ptr()); | new_node->set_kernel_info(old_node->kernel_info_ptr()); | ||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/parallel_node_check.h" | |||||
| #include <set> | |||||
| #include <string> | |||||
| namespace mindspore { | |||||
| // clang-format off | |||||
| static const std::set<std::string> PARALLEL_BLACK_LIST_ = {"tuple_getitem", "J", "list_getitem", | |||||
| "array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem", | |||||
| "list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array", | |||||
| "make_dict", "make_slice", "make_record", "string_equal", "VirtualLoss", "return", "env_getitem", | |||||
| "identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key", | |||||
| "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", | |||||
| "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", | |||||
| "InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", | |||||
| "stop_gradient", "Send"}; | |||||
| // clang-format on | |||||
| bool IsInParallelBlackList(const PrimitivePtr &prim) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end()); | |||||
| } | |||||
| bool IsParallelCareCNode(const CNodePtr &cnode) { | |||||
| if (cnode == nullptr || cnode->size() == 0) { | |||||
| return false; | |||||
| } | |||||
| const auto &prim_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| if (prim_node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| const auto &prim = prim_node->value()->cast<PrimitivePtr>(); | |||||
| if (prim == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (IsInParallelBlackList(prim)) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ | |||||
| #define MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ | |||||
| #include "ir/primitive.h" | |||||
| namespace mindspore { | |||||
| bool IsInParallelBlackList(const PrimitivePtr &); | |||||
| bool IsParallelCareCNode(const CNodePtr &); | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ | |||||
| @@ -84,9 +84,9 @@ def test_double_star_graph(): | |||||
| net.set_train() | net.set_train() | ||||
| _executor.compile(net, x, y, z, w, phase='train') | _executor.compile(net, x, y, z, w, phase='train') | ||||
| strategies = _executor._get_shard_strategy(net) | strategies = _executor._get_shard_strategy(net) | ||||
| expected_strategies = {'Default/network-Net/Cast-op5': [[8, 1]], | |||||
| 'Default/network-Net/Cast-op7': [[1, 8]], | |||||
| 'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]], | |||||
| 'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]], | |||||
| 'Default/network-Net/MatMul-op4': [[1, 8], [8, 1]]} | |||||
| expected_strategies = {'Default/network-Net/Cast-op2': [[8, 1]], | |||||
| 'Default/network-Net/Cast-op4': [[1, 8]], | |||||
| 'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]], | |||||
| 'Default/network-Net/MatMul-op5': [[1, 1], [1, 8]], | |||||
| 'Default/network-Net/MatMul-op1': [[1, 8], [8, 1]]} | |||||
| assert strategies == expected_strategies | assert strategies == expected_strategies | ||||
| @@ -79,8 +79,8 @@ def test_two_matmul_transpose(): | |||||
| net.set_train() | net.set_train() | ||||
| _executor.compile(net, x, y, b, phase='train') | _executor.compile(net, x, y, b, phase='train') | ||||
| strategies = _executor._get_shard_strategy(net) | strategies = _executor._get_shard_strategy(net) | ||||
| expected_strategies = {'Default/network-Net/Transpose-op4': [[1, 16]], | |||||
| 'Default/network-Net/Transpose-op5': [[16, 1]], | |||||
| 'Default/network-Net/MatMul-op7': [[16, 1], [1, 1]], | |||||
| 'Default/network-Net/MatMul-op6': [[16, 1], [1, 1]]} | |||||
| expected_strategies = {'Default/network-Net/Transpose-op1': [[1, 16]], | |||||
| 'Default/network-Net/Transpose-op2': [[16, 1]], | |||||
| 'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]], | |||||
| 'Default/network-Net/MatMul-op4': [[16, 1], [1, 1]]} | |||||
| assert strategies == expected_strategies | assert strategies == expected_strategies | ||||