/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/kernel.h"

// NOTE(review): the standard headers below were reconstructed from what this file uses
// (std::any_of/std::transform, std::inserter, std::stack) - confirm against the original include list.
#include <algorithm>
#include <iterator>
#include <stack>
#include "utils/ms_context.h"
#include "utils/anf_utils.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/optimizer/helper.h"

namespace mindspore {
namespace kernel {
// Dimension value that makes InferShapeForDefiniteOutputNode give up and fall back to the generic infer path.
constexpr int64_t kInvalidShape = -2;

/// \brief Replace the stored atomic-clean node list with the given one.
///
/// \param[in] atomic_clean_node Nodes to record; the stored list is resized to match and filled element by element.
void KernelMod::SetAtomicCleanNodes(const std::vector<CNodePtr> &atomic_clean_node) {
  atomic_clean_nodes_.resize(atomic_clean_node.size());
  for (size_t i = 0; i < atomic_clean_node.size(); ++i) {
    atomic_clean_nodes_[i] = atomic_clean_node[i];
  }
}

/// \brief Re-infer the abstract of the node this kernel is bound to.
///
/// Shape nodes with a definite output are handled by InferShapeForDefiniteOutputNode. Otherwise, for every input
/// index found in depend_list_, a host tensor backed by the input's device address is created, synced and attached
/// to the input abstract as its value; the collected argument abstracts are then passed to opt::CppInferShape and
/// the result is stored on the node.
///
/// Raises (via MS_LOG(EXCEPTION) / MS_EXCEPTION_IF_NULL) when the node, its inputs, or a required abstract is missing.
void KernelMod::InferShape() {
  auto node = anf_node_.lock();
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
  GetDepndLists(cnode);
  auto ret = InferShapeForDefiniteOutputNode(cnode);
  if (ret) {
    return;
  }

  depend_tensor_map_.clear();
  auto &inputs = cnode->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Invalid inputs.";
  }
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  AbstractBasePtrList args_spec_list;
  auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
  auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
  for (size_t i = 0; i < input_size; i++) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i);
    auto real_input = input_node_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);
    // inputs[0] is the primitive, so tensor input i lives at position i + 1.
    auto cnode_input = cnode->input(i + 1);
    MS_EXCEPTION_IF_NULL(cnode_input);
    InferShapeForNopNode(&real_input);
    if (depend_list_.find(i) != depend_list_.end()) {
      auto pre_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i);
      bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
      auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, i, skip_nop_node);
      std::vector<int64_t> shapes =
        trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
      auto host_type = common::AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
      auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
      MS_EXCEPTION_IF_NULL(out_tensor);
      // The second parameter must be false, otherwise the device address cannot be released and allocated, and the
      // address size will be wrong in the dynamic shape scenario.
      out_tensor->set_device_address(output_addr, false);
      auto ret2 = depend_tensor_map_.try_emplace(i, out_tensor);
      if (!ret2.second) {
        MS_LOG(EXCEPTION) << "Insert map failed.";
      }
      out_tensor->data_sync();

      // The lock is held until the end of this scope while the input abstract is updated.
      // cppcheck-suppress unreadVariable
      auto lock = AnfUtils::GetAbstractLock(real_input.get());
      auto real_abs = real_input->abstract();
      // An input without an abstract would otherwise be dereferenced below.
      MS_EXCEPTION_IF_NULL(real_abs);
      if (real_abs->isa<abstract::AbstractTensor>()) {
        real_abs->set_value(out_tensor);
      } else if (real_abs->isa<abstract::AbstractTuple>()) {
        auto tuple_get_item_index = common::AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
        auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
        MS_EXCEPTION_IF_NULL(abstract_tuple);
        const auto &elements = abstract_tuple->elements();
        if (tuple_get_item_index >= elements.size()) {
          MS_LOG(EXCEPTION) << "Tuple index " << tuple_get_item_index << " is out of range, tuple size: "
                            << elements.size() << ", node: " << cnode->fullname_with_scope();
        }
        auto tuple_elements = elements[tuple_get_item_index];
        MS_EXCEPTION_IF_NULL(tuple_elements);
        tuple_elements->set_value(out_tensor);
      }
    }
    common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
  }
  auto eval_result = opt::CppInferShape(primitive, args_spec_list);
  cnode->set_abstract(eval_result);
}

/// \brief Fast path for prim::kPrimShape nodes whose current shape contains no kInvalidShape dimension.
///
/// \param[in] cnode Node to inspect; must not be null.
/// \return true when the node's abstract shape was updated here and no further inference is needed; false when the
///         node is not a Shape node or its current shape contains kInvalidShape.
bool KernelMod::InferShapeForDefiniteOutputNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimShape)) {
    return false;
  }
  auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
  if (input_size != 1) {
    MS_LOG(EXCEPTION) << "Node only has one input: " << cnode->fullname_with_scope();
  }
  // Keep the shape object alive in a local so the reference taken below stays valid, and verify the downcast
  // before dereferencing: a missing or non-abstract::Shape shape yields nullptr here.
  auto base_shape = cnode->Shape();
  auto shape_ptr = dynamic_cast<mindspore::abstract::Shape *>(base_shape.get());
  MS_EXCEPTION_IF_NULL(shape_ptr);
  const auto &cur_shape = shape_ptr->shape();
  if (std::any_of(cur_shape.begin(), cur_shape.end(), [](int64_t x) { return x == kInvalidShape; })) {
    return false;
  }
  std::vector<int64_t> output_shape = {static_cast<int64_t>(cur_shape.size())};
  mindspore::abstract::BaseShapePtr shape = std::make_shared<mindspore::abstract::Shape>(output_shape);

  // cppcheck-suppress unreadVariable
  auto lock = AnfUtils::GetAbstractLock(cnode.get());
  auto abstract = cnode->abstract();
  MS_EXCEPTION_IF_NULL(abstract);
  abstract->set_shape(shape);
  return true;
}

/// \brief Infer shapes along a chain of nop nodes feeding an input.
///
/// Does nothing unless *input_node is both a nop node and dynamic-shape. Otherwise follows input 0 through
/// consecutive nop nodes, then calls AnfAlgo::InferShape on each collected node, starting with the one reached last.
///
/// \param[in,out] input_node Pointer to the input node; on return it refers to the last nop node of the chain.
void KernelMod::InferShapeForNopNode(AnfNodePtr *input_node) {
  // Validate the out-parameter itself before dereferencing it.
  MS_EXCEPTION_IF_NULL(input_node);
  MS_EXCEPTION_IF_NULL(*input_node);
  if (!common::AnfAlgo::IsNopNode(*input_node) || !common::AnfAlgo::IsDynamicShape(*input_node)) {
    MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
    return;
  }
  MS_LOG(INFO) << "Infer shape for nop node.";
  std::stack<AnfNodePtr> nop_road;
  nop_road.push(*input_node);

  /*lint -e716*/
  while (true) {
    auto input_node_with_idx = common::AnfAlgo::GetPrevNodeOutput(*input_node, 0);
    auto in_node = input_node_with_idx.first;
    MS_EXCEPTION_IF_NULL(in_node);
    if (common::AnfAlgo::IsNopNode(in_node)) {
      nop_road.push(in_node);
      *input_node = in_node;
    } else {
      break;
    }
  }

  /*lint +e716*/
  // Popping the stack visits the nop node reached last first.
  while (!nop_road.empty()) {
    auto nop_node = nop_road.top();
    MS_EXCEPTION_IF_NULL(nop_node);
    AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
    nop_road.pop();
  }
}

/// \brief Fill depend_list_ from abstract::GetDependsFormMap on first use.
///
/// Returns immediately when depend_list_ is already populated or when no depends are reported for the node.
///
/// \param[in] cnode Node whose depends are looked up; must not be null.
void KernelMod::GetDepndLists(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!depend_list_.empty()) {
    return;
  }
  auto ret = abstract::GetDependsFormMap(cnode);
  if (ret.empty()) {
    MS_LOG(DEBUG) << "No dynamic_shape_depends found.";
    return;
  }
  MS_LOG(INFO) << "Have depends.";
  (void)std::transform(ret.begin(), ret.end(), std::inserter(depend_list_, depend_list_.begin()),
                       [](int64_t value) { return static_cast<size_t>(value); });
  MS_LOG(INFO) << "Init End.";
}
}  // namespace kernel
}  // namespace mindspore