| @@ -23,68 +23,8 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM, | |||
| J, | |||
| LIST_GETITEM, | |||
| ARRAY_GETITEM, | |||
| TUPLE_SETITEM, | |||
| DEPEND, | |||
| LIST_SETITEM, | |||
| ARRAY_SETITEM, | |||
| DICT_GETITEM, | |||
| LIST_APPEND, | |||
| LIST_MAP, | |||
| LIST_REDUCE, | |||
| TUPLE_REVERSED, | |||
| TILE_SHAPE, | |||
| TUPLE_DIV, | |||
| TUPLE_TO_ARRAY, | |||
| MAKE_DICT, | |||
| MAKE_SLICE, | |||
| MAKE_RECORD, | |||
| STRING_EQUAL, | |||
| VIRTUALLOSS, | |||
| RETURN, | |||
| ENV_GETITEM, | |||
| IDENTITY, | |||
| PARTIAL, | |||
| ENVSETITEM, | |||
| ENVGETITEM, | |||
| ENVADD, | |||
| MAKEREFKEY, | |||
| MAKEREF, | |||
| GETREFKEY, | |||
| GETREFVALUE, | |||
| GETREFORIGIN, | |||
| DOT, | |||
| IM2COL, | |||
| COL2IM, | |||
| IM2COLV1, | |||
| STATESETITEM, | |||
| SCALARSUMMARY, | |||
| IMAGESUMMARY, | |||
| TENSORSUMMARY, | |||
| DEBUG, | |||
| HISTOGRAMSUMMARY, | |||
| COL2IMV1, | |||
| RESOLVE, | |||
| BROADCASTGRADIENTARGS, | |||
| INVERTPERMUTATION, | |||
| CONTROLDEPEND, | |||
| DROPOUT_GEN_MASK, | |||
| EMBED, | |||
| CREATINSTANCE, | |||
| REF_TO_EMBED, | |||
| STOP_GRADIENT, | |||
| SEND}; | |||
| const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER}; | |||
| bool IsInBlackList(const PrimitivePtr &prim) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end()); | |||
| } | |||
| bool IsInBatchParallelBlackList(const PrimitivePtr &prim) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return (BATCH_PARALLEL_BLACK_LIST.find(prim->name()) != BATCH_PARALLEL_BLACK_LIST.end()); | |||
| @@ -21,7 +21,6 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| bool IsInBlackList(const PrimitivePtr &prim); | |||
| bool IsInBatchParallelBlackList(const PrimitivePtr &prim); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -32,6 +32,7 @@ | |||
| #include "base/core_ops.h" | |||
| #include "utils/comm_manager.h" | |||
| #include "utils/ms_context.h" | |||
| #include "mindspore/core/utils/parallel_node_check.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -99,7 +100,7 @@ bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) { | |||
| if (IsInWhiteList(cnode)) { | |||
| return false; | |||
| } | |||
| if (IsInBlackList(prim)) { | |||
| if (IsInParallelBlackList(prim)) { | |||
| MS_LOG(INFO) << "PipelineSplit don't care node:" << prim->name(); | |||
| return false; | |||
| } | |||
| @@ -44,6 +44,7 @@ | |||
| #include "utils/comm_manager.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| #include "mindspore/core/utils/parallel_node_check.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/util.h" | |||
| #endif | |||
| @@ -439,7 +440,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| if (IsInBlackList(prim)) { | |||
| if (IsInParallelBlackList(prim)) { | |||
| MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name(); | |||
| return false; | |||
| } | |||
| @@ -26,6 +26,7 @@ | |||
| #include "utils/profile.h" | |||
| #include "utils/ms_context.h" | |||
| #include "ir/graph_utils.h" | |||
| #include "utils/parallel_node_check.h" | |||
| // namespace to support intermediate representation definition | |||
| namespace mindspore { | |||
| @@ -91,7 +92,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||
| new_node->set_inputs_value(old_node->inputs_value()); | |||
| ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); | |||
| new_node->set_scope(scope); | |||
| if (IsPrimitiveCNode(old_node, nullptr) && new_node->scope() == kDefaultScope) { | |||
| if (IsParallelCareCNode(old_node) && new_node->scope() == kDefaultScope) { | |||
| new_node->set_fullname_with_scope(old_node->fullname_with_scope()); | |||
| } | |||
| new_node->set_kernel_info(old_node->kernel_info_ptr()); | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "utils/parallel_node_check.h" | |||
| #include <set> | |||
| #include <string> | |||
| namespace mindspore { | |||
| // clang-format off | |||
| static const std::set<std::string> PARALLEL_BLACK_LIST_ = {"tuple_getitem", "J", "list_getitem", | |||
| "array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem", | |||
| "list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array", | |||
| "make_dict", "make_slice", "make_record", "string_equal", "VirtualLoss", "return", "env_getitem", | |||
| "identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key", | |||
| "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", | |||
| "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", | |||
| "InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", | |||
| "stop_gradient", "Send"}; | |||
| // clang-format on | |||
| bool IsInParallelBlackList(const PrimitivePtr &prim) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end()); | |||
| } | |||
| bool IsParallelCareCNode(const CNodePtr &cnode) { | |||
| if (cnode == nullptr || cnode->size() == 0) { | |||
| return false; | |||
| } | |||
| const auto &prim_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| if (prim_node == nullptr) { | |||
| return false; | |||
| } | |||
| const auto &prim = prim_node->value()->cast<PrimitivePtr>(); | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| if (IsInParallelBlackList(prim)) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ | |||
| #define MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| bool IsInParallelBlackList(const PrimitivePtr &); | |||
| bool IsParallelCareCNode(const CNodePtr &); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ | |||
| @@ -84,9 +84,9 @@ def test_double_star_graph(): | |||
| net.set_train() | |||
| _executor.compile(net, x, y, z, w, phase='train') | |||
| strategies = _executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-Net/Cast-op5': [[8, 1]], | |||
| 'Default/network-Net/Cast-op7': [[1, 8]], | |||
| 'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]], | |||
| 'Default/network-Net/MatMul-op4': [[1, 8], [8, 1]]} | |||
| expected_strategies = {'Default/network-Net/Cast-op2': [[8, 1]], | |||
| 'Default/network-Net/Cast-op4': [[1, 8]], | |||
| 'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op5': [[1, 1], [1, 8]], | |||
| 'Default/network-Net/MatMul-op1': [[1, 8], [8, 1]]} | |||
| assert strategies == expected_strategies | |||
| @@ -79,8 +79,8 @@ def test_two_matmul_transpose(): | |||
| net.set_train() | |||
| _executor.compile(net, x, y, b, phase='train') | |||
| strategies = _executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-Net/Transpose-op4': [[1, 16]], | |||
| 'Default/network-Net/Transpose-op5': [[16, 1]], | |||
| 'Default/network-Net/MatMul-op7': [[16, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op6': [[16, 1], [1, 1]]} | |||
| expected_strategies = {'Default/network-Net/Transpose-op1': [[1, 16]], | |||
| 'Default/network-Net/Transpose-op2': [[16, 1]], | |||
| 'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op4': [[16, 1], [1, 1]]} | |||
| assert strategies == expected_strategies | |||