|
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "frontend/parallel/step_parallel_utils.h"
-
- #include <inttypes.h>
- #include <sys/time.h>
- #include <algorithm>
-
- #include <map>
- #include <set>
- #include <string>
- #include <unordered_map>
- #include <utility>
-
- #include "base/core_ops.h"
- #include "frontend/operator/ops.h"
- #include "frontend/optimizer/optimizer.h"
- #include "frontend/parallel/context.h"
- #include "frontend/parallel/device_manager.h"
- #include "frontend/parallel/graph_util/generate_graph.h"
- #include "frontend/parallel/graph_util/graph_info.h"
- #include "frontend/parallel/graph_util/node_info.h"
- #include "frontend/parallel/node_check.h"
- #include "ir/param_info.h"
- #include "ir/tensor.h"
- #include "utils/trace_base.h"
- #include "utils/comm_manager.h"
- #include "utils/ms_context.h"
- #include "utils/symbolic.h"
- #include "mindspore/core/utils/parallel_node_check.h"
-
- namespace mindspore {
- namespace parallel {
- bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
- if (!cnode) {
- return false;
- }
- ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
- MS_EXCEPTION_IF_NULL(anf_node);
- PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
- return (prim->name() == name);
- }
-
- bool IsParallelCareNode(const CNodePtr &cnode) {
- MS_EXCEPTION_IF_NULL(cnode);
- ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
- if (prim_node == nullptr) {
- return false;
- }
- PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
- if (prim == nullptr) {
- return false;
- }
- if (IsInParallelBlackList(prim)) {
- MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
- return false;
- }
- // get_next is not in the forward graph, we need mark the get_next as the forward node
- if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
- return true;
- }
- if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
- return false;
- }
-
- return cnode->in_forward_flag();
- }
-
- Shapes GetValueListShape(const AnfNodePtr &node) {
- Shapes shapes;
- std::vector<ValuePtr> inputs_seq;
- if (IsValueNode<ValueList>(node)) {
- inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
- } else if (IsValueNode<ValueTuple>(node)) {
- inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
- } else {
- MS_LOG(EXCEPTION) << "node is eigther ValueList or ValueTuple";
- }
- for (auto &ele : inputs_seq) {
- auto tensor = ele->cast<tensor::TensorPtr>();
- if (tensor == nullptr) {
- MS_LOG(WARNING) << "The value node is not a tensor";
- break;
- }
- auto one_shape = tensor->shape();
- shapes.push_back(one_shape);
- }
- return shapes;
- }
-
- Shapes GetNodeShape(const AnfNodePtr &node) {
- MS_EXCEPTION_IF_NULL(node);
- Shapes shapes;
- if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
- return GetValueListShape(node);
- }
- BaseShapePtr base_shape_ptr = node->Shape();
- if (node->isa<CNode>()) {
- auto cnode = node->cast<CNodePtr>();
- if (IsValueNode<Primitive>(cnode->input(0))) {
- PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
- MS_EXCEPTION_IF_NULL(prim);
- if (prim->name() == MAKEREF) {
- AnfNodePtr ref_node = cnode->input(1);
- auto func_graph = cnode->func_graph();
- MS_EXCEPTION_IF_NULL(ref_node);
- MS_EXCEPTION_IF_NULL(func_graph);
- return GetRefKeyNodeShape(ref_node, func_graph);
- }
- }
- if (cnode->input(0)->isa<CNode>()) {
- if (cnode->inputs().size() < 2) {
- MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
- }
- base_shape_ptr = cnode->input(1)->Shape();
- }
- }
- if (base_shape_ptr == nullptr) {
- MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
- << node->fullname_with_scope();
- }
- auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
- if (tuple_shape_ptr != nullptr) {
- auto tuple_shape = tuple_shape_ptr->shape();
- for (auto &shape : tuple_shape) {
- auto each_shape = dyn_cast<abstract::Shape>(shape);
- MS_EXCEPTION_IF_NULL(each_shape);
- shapes.push_back(each_shape->shape());
- }
- } else {
- auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
- MS_EXCEPTION_IF_NULL(shape_ptr);
- shapes.push_back(shape_ptr->shape());
- }
- return shapes;
- }
-
- std::string CreateInstanceName(const CNodePtr &node, size_t index) {
- MS_EXCEPTION_IF_NULL(node);
- if (!IsValueNode<Primitive>(node->input(0))) {
- MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
- }
- std::string name_base = node->fullname_with_scope();
- std::string name = name_base + "_" + std::to_string(index);
- std::string instance_name = HashInstanceName(name);
- return instance_name;
- }
-
- void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
- if (new_node_input.empty()) {
- return;
- }
-
- auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
- auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
- MS_EXCEPTION_IF_NULL(prim);
-
- auto attrs = prim->attrs();
- auto iter = attrs.find(GROUP);
- if (iter != attrs.end()) {
- auto value = iter->second;
- MS_EXCEPTION_IF_NULL(value);
- if (value->isa<StringImm>()) {
- std::string hash_name = value->cast<StringImmPtr>()->value();
- MS_EXCEPTION_IF_NULL(g_device_manager);
- std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
- (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
- }
- }
- }
-
- std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
- const CNodePtr &node) {
- OperatorArgs arg_replace_op = replace_op.second;
- ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
- if (pyop_instance == nullptr) {
- MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
- }
- OperatorParams params = arg_replace_op.second;
- if (node->inputs().size() < 2) {
- // GetNext operator dose not has input
- if (node->inputs().size() == 1) {
- return {NewValueNode(pyop_instance)};
- }
- MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
- }
- std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
-
- if (replace_op.first == EMBEDDING_LOOKUP) {
- replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
- }
-
- if (!params.empty()) {
- Param param_first = *(params.begin());
- int64_t first_position = param_first.second;
- if (first_position == 1) {
- replace_input.pop_back();
- }
- for (auto ¶m : params) {
- AnfNodePtr val = NewValueNode(param.first.second);
- if (val == nullptr) {
- MS_LOG(EXCEPTION) << "Failure:val is nullptr";
- }
- int64_t position = param.second;
- (void)replace_input.insert(replace_input.begin() + position, val);
- }
- } else if (replace_op.first == SYNC_BATCH_NORM) {
- for (size_t i = 2; i < node->inputs().size(); ++i) {
- replace_input.push_back(node->input(i));
- }
- }
- SetCommunicationOpGroupLabel(replace_input);
- return replace_input;
- }
- } // namespace parallel
- } // namespace mindspore
|