Merge pull request !2160 from xianwz/master_graph_kerneltags/v0.5.0-beta
| @@ -13,3 +13,6 @@ | |||
| [submodule "graphengine"] | |||
| path = graphengine | |||
| url = https://gitee.com/mindspore/graphengine.git | |||
| [submodule "akg"] | |||
| path = akg | |||
| url = https://gitee.com/mindspore/akg.git | |||
| @@ -86,10 +86,14 @@ if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES) | |||
| include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain) | |||
| endif() | |||
| if (ENABLE_AKG AND ENABLE_D) | |||
| add_subdirectory("${CMAKE_SOURCE_DIR}/akg") | |||
| endif() | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") | |||
| add_subdirectory(mindspore/ccsrc) | |||
| if (ENABLE_TESTCASES) | |||
| add_subdirectory(tests) | |||
| endif() | |||
| include(cmake/package.cmake) | |||
| include(cmake/package.cmake) | |||
| @@ -0,0 +1 @@ | |||
| Subproject commit c460176523d039c8995f1d71089753725ebc0792 | |||
| @@ -246,6 +246,9 @@ checkopts "$@" | |||
| echo "---------------- mindspore: build start ----------------" | |||
| mkdir -pv "${BUILD_PATH}/package/mindspore/lib" | |||
| git submodule update --init graphengine | |||
| if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then | |||
| git submodule update --init --recursive akg | |||
| fi | |||
| build_exit() | |||
| { | |||
| @@ -308,7 +311,7 @@ build_mindspore() | |||
| if [[ "X$USE_GLOG" = "Xon" ]]; then | |||
| CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON" | |||
| fi | |||
| if [[ "X$ENABLE_AKG" = "Xon" ]]; then | |||
| if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then | |||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON" | |||
| fi | |||
| echo "${CMAKE_ARGS}" | |||
| @@ -236,6 +236,16 @@ if (ENABLE_GPU) | |||
| endif () | |||
| endif () | |||
| if (ENABLE_D AND ENABLE_AKG) | |||
| set (AKG_PATH ${CMAKE_SOURCE_DIR}/build/mindspore/akg) | |||
| install( | |||
| DIRECTORY | |||
| ${AKG_PATH}/akg | |||
| DESTINATION ${INSTALL_PY_DIR}/.. | |||
| COMPONENT mindspore | |||
| ) | |||
| endif () | |||
| if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset) | |||
| install( | |||
| DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -0,0 +1,35 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Providing akg compile with json""" | |||
| import sys | |||
| def run_compiler(op_json): | |||
| """ | |||
| Run AKG compiler to compile op with subprocess, if this process of | |||
| compilation failed, an exception will be raised | |||
| Args: | |||
| op_json (str): json string of the op | |||
| Returns: | |||
| None | |||
| """ | |||
| p = __import__("akg", globals(), locals(), ['ms'], 0) | |||
| func = getattr(p.ms, "compilewithjson") | |||
| res = func(op_json) | |||
| if not res: | |||
| raise ValueError("Compile error") | |||
| if __name__ == "__main__": | |||
| run_compiler(sys.argv[1]) | |||
| @@ -0,0 +1,71 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Providing multi process compile with json""" | |||
| import os | |||
| import subprocess | |||
| import sys | |||
| from multiprocessing import Pool, cpu_count | |||
| def _compile_akg_task(*json_strs): | |||
| """ | |||
| compile func called in single process | |||
| Parameters: | |||
| json_strs: list. List contains multiple kernel infos, suitable for json compile api. | |||
| """ | |||
| akg_compiler = os.path.join(os.path.split( | |||
| os.path.realpath(__file__))[0], "compiler.py") | |||
| for json_str in json_strs: | |||
| res = subprocess.run( | |||
| [sys.executable, akg_compiler, json_str], text=True) | |||
| if res.returncode != 0: | |||
| raise ValueError("Failed, args: {}!".format(json_str)) | |||
| def compile_akg_kernel_parallel(json_infos, process, waitime): | |||
| """ | |||
| compile kernel use multi processes | |||
| Parameters: | |||
| json_infos: list. list contain kernel info(task id and json str) | |||
| process: int. processes num | |||
| waittime: int. max time the function blocked | |||
| Returns: | |||
| True for all compile success, False for some failed. | |||
| """ | |||
| if not isinstance(json_infos, list): | |||
| raise ValueError("json_infos must be a list") | |||
| if not isinstance(process, int): | |||
| raise ValueError("process must be a num") | |||
| if not isinstance(waitime, int): | |||
| raise ValueError("waittime must be a num") | |||
| if process == 0 and json_infos: | |||
| process = 1 | |||
| cpu_proc_num = cpu_count() | |||
| max_proc_num = 16 | |||
| process = min([cpu_proc_num, max_proc_num, process]) | |||
| args = [[] for _ in range(process)] | |||
| for p, info in enumerate(json_infos): | |||
| args[p % process].append(info) | |||
| with Pool(processes=process) as pool: | |||
| res = pool.starmap_async(_compile_akg_task, args) | |||
| res.get(timeout=waitime) | |||
| return True | |||
| @@ -1,107 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Providing multi process compile with json""" | |||
| import json | |||
| import math | |||
| import os | |||
| import subprocess | |||
| import sys | |||
| from multiprocessing import Pool | |||
| def _compiletask(platform, *jsons): | |||
| """ | |||
| compile func called in single process | |||
| Parameters: | |||
| platform: str. AKG platform or TBE platform | |||
| *jsons: str. json str contain kernel info, suitable for json compile | |||
| api | |||
| """ | |||
| if platform == "AKG": | |||
| p = __import__("_akg", globals(), locals(), ['ms'], 0) | |||
| func = getattr(p.ms, "compilewithjson") | |||
| for json_item in jsons: | |||
| res = func(json_item) | |||
| if not res: | |||
| raise ValueError("Compile error") | |||
| if platform == "TBE": | |||
| tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py") | |||
| for json_item in jsons: | |||
| res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True) | |||
| if res.returncode != 0: | |||
| raise ValueError("Tbe compile error") | |||
| def compilekernelparallel(jsons, process, waitime): | |||
| """ | |||
| compile kernel use multi processes | |||
| Parameters: | |||
| jsons: list. json str list contain kernel info | |||
| process: int. processes num | |||
| waittime: int. max time the function blocked | |||
| """ | |||
| if not isinstance(jsons, list): | |||
| raise ValueError("jsons must be a list") | |||
| if not isinstance(process, int): | |||
| raise ValueError("process must be a num") | |||
| if not isinstance(waitime, int): | |||
| raise ValueError("waittime must be a num") | |||
| jsons_akg = [] | |||
| jsons_tbe = [] | |||
| for json_ in jsons: | |||
| j = json.loads(json_) | |||
| if j["platform"] == "TBE": | |||
| jsons_tbe.append(json_) | |||
| continue | |||
| if j["platform"] == "AKG": | |||
| jsons_akg.append(json_) | |||
| continue | |||
| raise RuntimeError( | |||
| "not support this platform {0}".format(j["platform"])) | |||
| if jsons_akg: | |||
| process_akg = math.floor(len(jsons)/len(jsons_akg)*process) | |||
| else: | |||
| process_akg = 0 | |||
| if process_akg == 0 and jsons_akg: | |||
| process_akg = 1 | |||
| process_tbe = process-process_akg | |||
| if process_tbe == 0 and jsons_tbe: | |||
| process_tbe = 1 | |||
| raise RuntimeWarning("we add a process for compile more operator") | |||
| args = [[] for _ in range(process_akg+process_tbe)] | |||
| args_lens = len(args) | |||
| for p in range(args_lens): | |||
| if p < process_tbe: | |||
| args[p].append("TBE") | |||
| else: | |||
| args[p].append("AKG") | |||
| jsons_tbe_lens = len(jsons_tbe) | |||
| for p in range(jsons_tbe_lens): | |||
| args[p % process_tbe].append(jsons_tbe[p]) | |||
| jsons_akg_lens = len(jsons_akg) | |||
| for p in range(jsons_akg_lens): | |||
| args[process-p % process_akg-1].append(jsons_akg[p]) | |||
| for p in range(args_lens): | |||
| args[p] = tuple(args[p]) | |||
| with Pool(processes=process) as pool: | |||
| res = pool.starmap_async(_compiletask, args) | |||
| res.get(timeout=waitime) | |||
| return True | |||
| @@ -39,7 +39,7 @@ if(ENABLE_GPU) | |||
| "device/gpu/*.cu" | |||
| "kernel/gpu/*.cu" | |||
| "kernel/akg/gpu/*.cc" | |||
| "kernel/akg/akgkernelbuild.cc" | |||
| "kernel/akg/akg_kernel_build.cc" | |||
| "kernel/akg/akg_kernel_attrs_process.cc" | |||
| ) | |||
| @@ -428,6 +428,10 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s | |||
| auto temp_shape = shape; | |||
| std::vector<size_t> device_shape; | |||
| if (format == kOpFormat_FRAC_NZ) { | |||
| if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) { | |||
| // For [1] and [1024] shape we can trait it as NZ shape | |||
| return shape; | |||
| } | |||
| if (shape.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size(); | |||
| } else { | |||
| @@ -111,9 +111,15 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer) | |||
| } | |||
| buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl; | |||
| buffer << "#flags :" << std::endl; | |||
| for (const auto &flag : graph->flags()) { | |||
| buffer << flag.first << " : " << flag.second << std::endl; | |||
| buffer << "#attrs :" << std::endl; | |||
| for (const auto &attr : graph->attrs()) { | |||
| buffer << attr.first << " : "; | |||
| if (attr.second->isa<BoolImm>()) { | |||
| buffer << GetValue<bool>(attr.second); | |||
| } else if (attr.second->isa<StringImm>()) { | |||
| buffer << GetValue<std::string>(attr.second); | |||
| } | |||
| buffer << std::endl; | |||
| } | |||
| } | |||
| @@ -417,10 +423,16 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo> | |||
| fout << std::endl; | |||
| for (const auto &sg : *sub_graphs) { | |||
| fout << "subgraph flag:" << std::endl; | |||
| fout << "subgraph attr:" << std::endl; | |||
| MS_EXCEPTION_IF_NULL(sg.first); | |||
| for (const auto &flag : sg.first->flags()) { | |||
| fout << flag.first << " : " << flag.second << std::endl; | |||
| for (const auto &attr : sg.first->attrs()) { | |||
| fout << attr.first << " : "; | |||
| if (attr.second->isa<BoolImm>()) { | |||
| fout << GetValue<bool>(attr.second); | |||
| } else if (attr.second->isa<StringImm>()) { | |||
| fout << GetValue<std::string>(attr.second); | |||
| } | |||
| fout << std::endl; | |||
| } | |||
| fout << "subgraph @" << sg.first->ToString() << "."; | |||
| fout << sg.first->debug_info()->get_id() << "("; | |||
| @@ -548,9 +548,15 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr | |||
| for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | |||
| cur_cnode_ptr = cnode_ptr_list[i]; | |||
| MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | |||
| ValuePtr value_ptr = nullptr; | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); | |||
| if (primitive != nullptr) { | |||
| value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); | |||
| } else { | |||
| auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cur_cnode_ptr); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| value_ptr = func_graph->get_attr(kStreamNeedActivedFirst); | |||
| } | |||
| if (value_ptr == nullptr) { | |||
| continue; | |||
| } | |||
| @@ -26,10 +26,12 @@ | |||
| #include "kernel/kernel.h" | |||
| #include "kernel/tbe/tbe_kernel_build.h" | |||
| #include "kernel/tbe/tbe_kernel_parallel_build.h" | |||
| #include "kernel/akg/ascend/akg_ascend_kernel_build.h" | |||
| #include "kernel/aicpu/aicpu_kernel_build.h" | |||
| #include "kernel/hccl/hccl_kernel_build.h" | |||
| #include "kernel/rts/rt_kernel_build.h" | |||
| #include "kernel/tbe/tbe_utils.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "operator/ops.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "./common.h" | |||
| @@ -91,6 +93,7 @@ static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph | |||
| static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | |||
| std::vector<AnfNodePtr> tbe_nodes; | |||
| std::vector<AnfNodePtr> akg_nodes; | |||
| std::vector<AnfNodePtr> other_nodes; | |||
| for (const auto &anf_node : kernel_graph_ptr->execution_order()) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| @@ -105,19 +108,26 @@ static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *ke | |||
| } | |||
| break; | |||
| } | |||
| case KernelType::AKG_KERNEL: { | |||
| akg_nodes.push_back(anf_node); | |||
| break; | |||
| } | |||
| default: { | |||
| other_nodes.push_back(anf_node); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| bool ret = kernel::TbeOpParallelBuild(tbe_nodes); | |||
| bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes); | |||
| bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes); | |||
| auto bin_map = kernel::tbe::KernelMeta::GetInstance(); | |||
| (void)bin_map->ReadIndex(kernel::kCceKernelMeta); | |||
| for (const auto &anf_node : other_nodes) { | |||
| kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod_ptr); | |||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||
| } | |||
| return ret; | |||
| return tbe_ret && akg_ret; | |||
| } | |||
| static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) { | |||
| @@ -234,7 +244,7 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { | |||
| for (const auto &anf_node : kernel_graph->execution_order()) { | |||
| std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); | |||
| if (apply_function_name == prim::kPrimMaxPoolGrad->name() && | |||
| AnfAlgo::GetKernelType(anf_node) == KernelType::AUTO_DIFF_KERNEL) { | |||
| AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { | |||
| auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName); | |||
| MS_EXCEPTION_IF_NULL(clear_zero_prim); | |||
| auto new_value_node = NewValueNode(clear_zero_prim); | |||
| @@ -15,16 +15,27 @@ | |||
| */ | |||
| #include "device/ascend/kernel_select_ascend.h" | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "kernel/kernel_query.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include "common/utils.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/func_graph.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "device/kernel_info.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "kernel/kernel_query.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| @@ -121,12 +132,23 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons | |||
| } | |||
| auto pri_match_format = GetPriorityMatchFormat(kernel_node); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| auto input_anf_node = kernel_node->input(input_index + 1); | |||
| // we do not take ValueNode into consideration in graph kernel. | |||
| if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) { | |||
| if (input_anf_node->isa<ValueNode>() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) { | |||
| continue; | |||
| } | |||
| } | |||
| auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore; | |||
| if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score; | |||
| } | |||
| if (kernel_build_info.GetInputDeviceType(input_index) == | |||
| AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) { | |||
| // we match output fix precision first. | |||
| auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index); | |||
| if (prev_device_type == kTypeUnknown) { | |||
| prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); | |||
| } | |||
| if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score; | |||
| } | |||
| if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { | |||
| @@ -146,41 +168,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons | |||
| } | |||
| } | |||
| void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); | |||
| MS_EXCEPTION_IF_NULL(input_with_index.first); | |||
| auto real_input_node = input_with_index.first; | |||
| if (real_input_node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| bool is_ref = false; | |||
| auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); | |||
| if (op_info != nullptr) { | |||
| is_ref = op_info->is_ref(); | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kPynativeMode && | |||
| AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { | |||
| continue; | |||
| } | |||
| // we set special device info of a input tensor. | |||
| if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { | |||
| std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; | |||
| builder->SetOutputsFormat(output_format); | |||
| std::vector<TypeId> output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)}; | |||
| builder->SetOutputsDeviceType(output_type); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); | |||
| } | |||
| } | |||
| } | |||
| void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) { | |||
| MS_EXCEPTION_IF_NULL(support_index); | |||
| int index = kUnSupportMixedDataTypeIndex; | |||
| @@ -467,6 +454,51 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis | |||
| } | |||
| } // namespace | |||
| void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); | |||
| MS_EXCEPTION_IF_NULL(input_with_index.first); | |||
| auto real_input_node = input_with_index.first; | |||
| if (real_input_node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) { | |||
| continue; | |||
| } | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| if (IsValueNode<tensor::Tensor>(input_kernel_node) && | |||
| AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { | |||
| std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; | |||
| builder->SetOutputsFormat(output_format); | |||
| std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; | |||
| builder->SetOutputsDeviceType(output_type); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); | |||
| continue; | |||
| } | |||
| // we set special device info of a input tensor. | |||
| bool is_ref = false; | |||
| auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); | |||
| if (op_info != nullptr) { | |||
| is_ref = op_info->is_ref(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | |||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode && | |||
| AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { | |||
| std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; | |||
| builder->SetOutputsFormat(output_format); | |||
| std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; | |||
| builder->SetOutputsDeviceType(output_type); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); | |||
| } | |||
| } | |||
| } | |||
| KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, | |||
| const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| @@ -498,11 +530,17 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, | |||
| return select_status; | |||
| } | |||
| KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { | |||
| KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list; | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| kernel::KernelQuery(kernel_node, &kernel_info_list); | |||
| if (AnfAlgo::IsGraphKernel(kernel_node)) { | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex)); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| SelectGraphKernelInfo(kernel_node, func_graph); | |||
| return kStatusAllMatched; | |||
| } | |||
| kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type); | |||
| auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | |||
| // If aicore not find valid kernel info reloading aicpu kernel info list to find it | |||
| if (select_status == kNoMatched) { | |||
| @@ -27,7 +27,10 @@ enum KernelSelectStatus { | |||
| kStatusReducePrecision = 1, | |||
| kStatusRaisePrecision = 2, | |||
| }; | |||
| KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node); | |||
| KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, | |||
| KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | |||
| void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node); | |||
| void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph); | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,516 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "device/ascend/kernel_select_ascend.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "device/kernel_info.h" | |||
| #include "ir/func_graph.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "kernel/kernel_query.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| TypeId GetPrimitivePrecision(const CNodePtr &cnode) { | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| TypeId except_type = kTypeUnknown; | |||
| if (primitive->GetAttr(kAttrFixPrecision) != nullptr) { | |||
| auto strExceptDtype = GetValue<std::string>(primitive->GetAttr(kAttrFixPrecision)); | |||
| if (strExceptDtype == "float16") { | |||
| except_type = kNumberTypeFloat16; | |||
| } else if (strExceptDtype == "float32") { | |||
| except_type = kNumberTypeFloat32; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got" << strExceptDtype; | |||
| } | |||
| } | |||
| return except_type; | |||
| } | |||
| void ResetKernelBuildInfo(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); | |||
| if (!kernel::IsWeightBoundary(kernel_with_index.first)) { | |||
| continue; | |||
| } | |||
| // reset format and dtype. | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT}); | |||
| builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get()); | |||
| } | |||
| } | |||
| void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) { | |||
| for (size_t i = 0; i < node_list.size(); ++i) { | |||
| // select nodes in subgraph. | |||
| auto anf_node = node_list[i]; | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto fix_precision_type = GetPrimitivePrecision(cnode); | |||
| if (fix_precision_type != kTypeUnknown) { | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL); | |||
| for (size_t index = 0; index < kernel_info_list.size(); ++index) | |||
| // only math the first input | |||
| if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type && | |||
| kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) && | |||
| AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) { | |||
| auto selected_kernel_info_ptr = kernel_info_list[index]; | |||
| ResetKernelBuildInfo(cnode); | |||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get()); | |||
| SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| bool CanConvertDefaultShapeToNZ(const std::vector<size_t> &shape) { | |||
| for (size_t i = 1; i <= shape.size(); ++i) { | |||
| if (i > 2) { | |||
| break; | |||
| } | |||
| if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| std::vector<int> DefaultToFracNZAxis(const std::vector<size_t> &ori_shape, const std::vector<int> &axis) { | |||
| std::vector<int> frac_nz_axis = axis; | |||
| auto shape_len = ori_shape.size(); | |||
| for (size_t i = 0; i < axis.size(); ++i) { | |||
| auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len; | |||
| if (axis_idx == shape_len - 1) { | |||
| frac_nz_axis[i] = axis_idx - 1; | |||
| frac_nz_axis.push_back(axis_idx + 2); | |||
| } else if (axis_idx == shape_len - 2) { | |||
| frac_nz_axis[i] = axis_idx + 1; | |||
| frac_nz_axis.push_back(axis_idx + 2); | |||
| } else { | |||
| frac_nz_axis[i] = axis_idx; | |||
| } | |||
| } | |||
| return frac_nz_axis; | |||
| } | |||
| std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape, const std::vector<int> &axis, | |||
| bool keep_dims) { | |||
| std::vector<size_t> result; | |||
| std::set<size_t> positive_idx; | |||
| for (const auto &a : axis) { | |||
| positive_idx.insert(a >= 0 ? a : ori_shape.size() + a); | |||
| } | |||
| for (size_t i = 0; i < ori_shape.size(); ++i) { | |||
| if (positive_idx.count(i) == 0) { | |||
| result.push_back(ori_shape[i]); | |||
| } else if (keep_dims) { | |||
| result.push_back(1); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| void UpdateFracNZReduceOp(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0); | |||
| if (input_format == kOpFormat_FRAC_NZ) { | |||
| // Clone primitive to modify it | |||
| auto prim = GetCNodePrimitive(cnode); | |||
| auto new_prim = std::make_shared<Primitive>(*prim); | |||
| auto new_prim_node = NewValueNode(new_prim); | |||
| cnode->set_input(0, new_prim_node); | |||
| auto axis_value = new_prim->GetAttr(kAttrAxis); | |||
| std::vector<int> default_axis; | |||
| if (axis_value->isa<ValueList>()) { | |||
| auto value_list = dyn_cast<ValueList>(axis_value); | |||
| for (const auto &item : value_list->value()) { | |||
| if (item->isa<Int32Imm>()) { | |||
| default_axis.push_back(GetValue<int32_t>(item)); | |||
| } | |||
| } | |||
| } else if (axis_value->isa<ValueTuple>()) { | |||
| auto value_tuple = dyn_cast<ValueTuple>(axis_value); | |||
| for (const auto &item : value_tuple->value()) { | |||
| if (item->isa<Int32Imm>()) { | |||
| default_axis.push_back(GetValue<int32_t>(item)); | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Axis attr type is not correct!"; | |||
| } | |||
| auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | |||
| std::vector<int> frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<std::vector<int>>(frac_nz_axis), cnode); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0); | |||
| if (output_shape.size() == 1) { | |||
| AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue<bool>(true), cnode); | |||
| } | |||
| } | |||
| } | |||
| void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(default_format); | |||
| MS_EXCEPTION_IF_NULL(use_same_format); | |||
| std::unordered_map<std::string, size_t> all_input_formats; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| if (!input_kernel_node->isa<Parameter>()) { | |||
| auto pre_format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i); | |||
| ++all_input_formats[pre_format]; | |||
| continue; | |||
| } | |||
| auto para = input_kernel_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(para); | |||
| if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { | |||
| auto pre_format = AnfAlgo::GetOutputFormat(para, 0); | |||
| ++all_input_formats[pre_format]; | |||
| continue; | |||
| } | |||
| *use_same_format = false; | |||
| } | |||
| if (all_input_formats.empty()) { | |||
| // all inputs are parameter. | |||
| *default_format = kOpFormat_NC1HWC0; | |||
| } else { | |||
| std::vector<std::pair<std::string, size_t>> pairs; | |||
| for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) { | |||
| pairs.push_back(std::make_pair(iter->first, iter->second)); | |||
| } | |||
| auto cmp_func = [](const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) { | |||
| if (a.second != b.second) { | |||
| return a.second > b.second; | |||
| } else if (a.first == kOpFormat_DEFAULT) { | |||
| return a.second + 1 > b.second; | |||
| } else if (b.first == kOpFormat_DEFAULT) { | |||
| return a.second > b.second + 1; | |||
| } | |||
| return a.second > b.second; | |||
| }; | |||
| std::sort(pairs.begin(), pairs.end(), cmp_func); | |||
| *default_format = pairs.begin()->first; | |||
| } | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| if (!input_kernel_node->isa<Parameter>() || | |||
| AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) { | |||
| continue; | |||
| } | |||
| auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0); | |||
| if (weight_infer_shape.size() < 2 && *default_format == kOpFormat_FRAC_NZ) { | |||
| *default_format = kOpFormat_DEFAULT; | |||
| *use_same_format = true; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| void UpdateGraphKernelInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list, | |||
| const std::string &default_format, bool use_same_format, | |||
| std::vector<std::string> *graph_input_format, | |||
| std::vector<TypeId> *graph_input_type) { | |||
| MS_EXCEPTION_IF_NULL(graph_input_format); | |||
| MS_EXCEPTION_IF_NULL(graph_input_type); | |||
| // We set same format to all inputs of graph kernel subgraph, and process this latter. | |||
| // We set dtype to inputs of graph kernel subgraph same as infer dtypes. | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| if (use_same_format) { | |||
| bool can_convert = true; | |||
| if (default_format == kOpFormat_FRAC_NZ) { | |||
| auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); | |||
| if (!CanConvertDefaultShapeToNZ(infer_shape)) { | |||
| MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead"; | |||
| can_convert = false; | |||
| } | |||
| } | |||
| if (can_convert) { | |||
| graph_input_format->push_back(default_format); | |||
| } else { | |||
| graph_input_format->push_back(kOpFormat_DEFAULT); | |||
| } | |||
| graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i)); | |||
| continue; | |||
| } | |||
| if (!input_kernel_node->isa<Parameter>()) { | |||
| // subgraph parameter from output of other nodes. | |||
| graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)); | |||
| graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i)); | |||
| continue; | |||
| } | |||
| auto para = input_kernel_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(para); | |||
| if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { | |||
| // parameter already selected. | |||
| graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0)); | |||
| graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0)); | |||
| continue; | |||
| } | |||
| // weight parameter. | |||
| graph_input_format->push_back(default_format); | |||
| graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)); | |||
| } | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| std::vector<std::string> outputs_format = {(*graph_input_format)[i]}; | |||
| std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]}; | |||
| builder.SetOutputsFormat(outputs_format); | |||
| builder.SetOutputsDeviceType(outputs_device_type); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); | |||
| } | |||
| } | |||
| void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_index, | |||
| const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph, | |||
| const FuncGraphManagerPtr &mng) { | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| for (size_t i = 0; i < node_list.size(); ++i) { | |||
| // select nodes in subgraph. | |||
| auto anf_node = node_list[i]; | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| cnode->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| SelectKernelInfo(cnode, KernelType::AKG_KERNEL); | |||
| // Update ReduceSum | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) { | |||
| continue; | |||
| } | |||
| UpdateFracNZReduceOp(cnode); | |||
| // If ReduceSum's output is 1d and not Default format, convert it to Default format | |||
| auto out_format = AnfAlgo::GetOutputFormat(cnode, 0); | |||
| if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) { | |||
| continue; | |||
| } | |||
| auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0); | |||
| // Insert EquivFormat node, then select kernel info again | |||
| std::vector<AnfNodePtr> trans_inputs; | |||
| trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat)); | |||
| trans_inputs.push_back(cnode); | |||
| CNodePtr trans_node = func_graph->NewCNode(trans_inputs); | |||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)}, | |||
| {AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue<std::vector<std::string>>({"x"}), trans_node); | |||
| if (trans_node->kernel_info() == nullptr) { | |||
| trans_node->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| } | |||
| SelectKernelInfo(trans_node, KernelType::AKG_KERNEL); | |||
| mng->Replace(cnode, trans_node); | |||
| } | |||
| } | |||
| void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list, | |||
| const std::vector<AnfNodePtr> &input_list, const FuncGraphManagerPtr &mng, | |||
| const std::string &default_format, std::vector<std::string> *graph_input_format, | |||
| std::vector<TypeId> *graph_input_type) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| MS_EXCEPTION_IF_NULL(graph_input_format); | |||
| MS_EXCEPTION_IF_NULL(graph_input_type); | |||
| // update graph input format and dtype use inner ops. | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (graph_input_format->size() != input_num) { | |||
| MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() | |||
| << "], [%" << graph_input_format->size() << "] != [%" << input_num << "]"; | |||
| } | |||
| std::vector<bool> need_update(input_num, false); | |||
| auto &node_users = mng->node_users(); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto &input = input_list[i]; | |||
| auto iter = node_users.find(input); | |||
| if (iter == node_users.end() || iter->second.empty()) { | |||
| continue; | |||
| } | |||
| for (auto &node_user : iter->second) { | |||
| if (node_user.first->kernel_info() == nullptr || | |||
| node_user.first->kernel_info()->select_kernel_build_info() == nullptr) { | |||
| // maybe not a real kernel. | |||
| continue; | |||
| } | |||
| auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1)); | |||
| if (user_format != (*graph_input_format)[i]) { | |||
| MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" | |||
| << kernel_node->DebugString() | |||
| << "] selected different format. we use defult: " << default_format; | |||
| (*graph_input_format)[i] = default_format; | |||
| need_update[i] = true; | |||
| } | |||
| if (kernel_node->input(i + 1)->isa<Parameter>()) { | |||
| auto user_dtype = AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)); | |||
| if (user_dtype != (*graph_input_type)[i]) { | |||
| TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0); | |||
| MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" | |||
| << kernel_node->DebugString() | |||
| << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype); | |||
| (*graph_input_type)[i] = default_dtype; | |||
| need_update[i] = true; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| if (!need_update[i]) { | |||
| continue; | |||
| } | |||
| need_update[i] = false; | |||
| MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString() | |||
| << "] to: " << (*graph_input_format)[i]; | |||
| MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString() | |||
| << "] to: " << TypeIdLabel((*graph_input_type)[i]); | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| std::vector<std::string> outputs_format = {(*graph_input_format)[i]}; | |||
| std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]}; | |||
| builder.SetOutputsFormat(outputs_format); | |||
| builder.SetOutputsDeviceType(outputs_device_type); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); | |||
| } | |||
| ResetKernelBuildInfo(kernel_node); | |||
| // select nodes in subgraph again. | |||
| for (size_t i = 0; i < node_list.size(); ++i) { | |||
| auto anf_node = node_list[i]; | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t j = 0; j < cnode_input_num; ++j) { | |||
| auto input_node = cnode->input(j + 1); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (!IsValueNode<tensor::Tensor>(input_node)) { | |||
| continue; | |||
| } | |||
| // reset format and dtype of const tensor. | |||
| builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT}); | |||
| builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get()); | |||
| } | |||
| SelectKernelInfo(node_list[i]->cast<CNodePtr>(), KernelType::AKG_KERNEL); | |||
| } | |||
| } | |||
| void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair<AnfNodePtr, size_t>> &output_index, | |||
| const std::vector<std::string> &graph_input_format, | |||
| const std::vector<TypeId> &graph_input_type) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::vector<std::string> graph_output_format; | |||
| std::vector<TypeId> graph_output_type; | |||
| for (size_t i = 0; i < output_index.size(); ++i) { | |||
| auto const &output = output_index[i]; | |||
| graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second)); | |||
| TypeId output_type(kTypeUnknown); | |||
| if (output.first->isa<CNode>()) { | |||
| output_type = AnfAlgo::GetCNodeOutputPrecision(output.first); | |||
| } | |||
| if (output_type == kTypeUnknown) { | |||
| output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second); | |||
| } | |||
| graph_output_type.push_back(output_type); | |||
| } | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; | |||
| graph_info_builder.SetInputsFormat(graph_input_format); | |||
| graph_info_builder.SetInputsDeviceType(graph_input_type); | |||
| graph_info_builder.SetOutputsFormat(graph_output_format); | |||
| graph_info_builder.SetOutputsDeviceType(graph_output_type); | |||
| graph_info_builder.SetProcessor(kernel::Processor::AICORE); | |||
| graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); | |||
| graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); | |||
| auto graph_selected_info = graph_info_builder.Build(); | |||
| MS_EXCEPTION_IF_NULL(graph_selected_info); | |||
| AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get()); | |||
| SetTensorDeviceInfo(*graph_selected_info, kernel_node); | |||
| } | |||
| void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| // collect input info of funcgraph | |||
| std::vector<AnfNodePtr> node_list; | |||
| std::vector<AnfNodePtr> input_list; | |||
| std::vector<AnfNodePtr> output_list; | |||
| kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); | |||
| if (input_list.size() != kernel_node->inputs().size() - 1) { | |||
| MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode[" | |||
| << kernel_node->DebugString() << "], [%" << input_list.size() << "] != [" | |||
| << kernel_node->inputs().size() << "]"; | |||
| } | |||
| std::string default_format; | |||
| bool use_same_format = true; | |||
| GetDefaultFormat(kernel_node, &default_format, &use_same_format); | |||
| MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] use same input format[" << default_format | |||
| << "] for ParameterWeight."; | |||
| std::vector<std::string> graph_input_format; | |||
| std::vector<TypeId> graph_input_type; | |||
| UpdateGraphKernelInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, | |||
| &graph_input_type); | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| } | |||
| auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list); | |||
| UpdateEquivFormat(output_index, node_list, func_graph, mng); | |||
| node_list.clear(); | |||
| input_list.clear(); | |||
| output_list.clear(); | |||
| kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); | |||
| // update graph input format and dtype use inner ops. | |||
| UpdateFormatsAndDtypes(kernel_node, node_list, input_list, mng, default_format, &graph_input_format, | |||
| &graph_input_type); | |||
| // set fix_precision for kernel when the me prim has fix_precision attr | |||
| UpdateKernelInfo(node_list); | |||
| output_index = kernel::GetOutputIndex(node_list, input_list, output_list); | |||
| SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type); | |||
| } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -24,7 +24,7 @@ namespace device { | |||
| namespace ascend { | |||
| void GraphDescReporter::ReportData() { | |||
| for (const auto &node : cnode_list_) { | |||
| if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) { | |||
| if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { | |||
| MS_LOG(WARNING) << "Skip non tbe kernel"; | |||
| continue; | |||
| } | |||
| @@ -31,7 +31,7 @@ void TaskDescReporter::ReportData() { | |||
| size_t task_index = 0; | |||
| for (const auto &node : cnode_list_) { | |||
| if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) { | |||
| if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { | |||
| MS_LOG(WARNING) << "Skip non tbe kernel"; | |||
| ++task_index; | |||
| continue; | |||
| @@ -43,7 +43,37 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve | |||
| void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { | |||
| MS_EXCEPTION_IF_NULL(anf_node_ptr); | |||
| if (anf_node_ptr->inputs().size() != 2) { | |||
| MS_LOG(EXCEPTION) << "atomic Addr clean Node Input nodes not equal 2."; | |||
| // akg process | |||
| // set atomic clean addr | |||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, anf_node_ptr)) { | |||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAutomicOutputIndexs); | |||
| auto graph = anf_node_ptr->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto node_users = manager->node_users(); | |||
| if (node_users[anf_node_ptr].empty()) { | |||
| MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty."; | |||
| } | |||
| auto depend_node = node_users[anf_node_ptr].pop().first; | |||
| if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) { | |||
| MS_LOG(EXCEPTION) << "Checking Depend node failed"; | |||
| } | |||
| if (node_users[depend_node].empty()) { | |||
| MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty."; | |||
| } | |||
| auto post_node = node_users[depend_node].pop().first; | |||
| for (auto index : clean_output_indexs) { | |||
| auto device_address = AnfAlgo::GetOutputAddr(post_node, index); | |||
| kernel::AddressPtr input = std::make_shared<kernel::Address>(); | |||
| input->addr = device_address->ptr_; | |||
| MS_EXCEPTION_IF_NULL(input->addr); | |||
| input->size = device_address->size_; | |||
| kernel_inputs->push_back(input); | |||
| } | |||
| MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size(); | |||
| } | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); | |||
| auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>(); | |||
| @@ -59,7 +89,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP | |||
| input->size = device_address->size_; | |||
| kernel_inputs->push_back(input); | |||
| } | |||
| MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | |||
| MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | |||
| } | |||
| // set clean workspace address | |||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) { | |||
| @@ -16,7 +16,7 @@ | |||
| #include "device/gpu/gpu_kernel_build.h" | |||
| #include <string> | |||
| #include "kernel/kernel.h" | |||
| #include "kernel/akg/akgkernelbuild.h" | |||
| #include "kernel/akg/akg_kernel_build.h" | |||
| #include "kernel/akg/gpu/akg_gpu_kernel_build.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "operator/ops.h" | |||
| @@ -37,7 +37,7 @@ void GpuBuild(const KernelGraphPtr &kernel_graph) { | |||
| continue; | |||
| } | |||
| if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AUTO_DIFF_KERNEL) { | |||
| if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) { | |||
| auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); | |||
| if (!gpu_kernel_ptr) { | |||
| MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; | |||
| @@ -184,7 +184,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) { | |||
| if (!result) { | |||
| result = SelectAkgKernel(kernel_node, builder->Build()); | |||
| kernel_type = AUTO_DIFF_KERNEL; | |||
| kernel_type = AKG_KERNEL; | |||
| } | |||
| if (!result) { | |||
| @@ -26,6 +26,8 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive_base.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| // namespace to support intermediate representation definition | |||
| CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph) | |||
| @@ -106,10 +108,14 @@ std::string ValueNode::fullname_with_scope() { | |||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode != nullptr) { | |||
| if (cnode == nullptr) { | |||
| return false; | |||
| } | |||
| if (value != nullptr) { | |||
| return cnode->IsApply(value); | |||
| } | |||
| return false; | |||
| const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| return prim != nullptr; | |||
| } | |||
| PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { | |||
| @@ -124,6 +124,7 @@ class AnfNode : public Base { | |||
| const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); } | |||
| KernelInfoDevice *kernel_info() { return kernel_info_.get(); } | |||
| const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; } | |||
| void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; } | |||
| AbstractBasePtr abstract() const { return abstract_; } | |||
| @@ -395,9 +396,9 @@ static S GetValue(const ValuePtr &value) { | |||
| std::string GetCNodeFuncName(CNodePtr cnode); | |||
| // used to check whether an AnfNode is a cnode with a kind of Primitive as first input | |||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value); | |||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr); | |||
| // used to check whether an AnfNode is a cnode with a Primitive as first input | |||
| // used to get PrimitivePtr from a cnode first input | |||
| PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); | |||
| // used to check whether an AnfNode is a valuenode having some Primitive value | |||
| @@ -70,7 +70,7 @@ std::string CNode::fullname_with_scope() { | |||
| } | |||
| fullname_with_scope_ = name; | |||
| } else { | |||
| // cnode input 0 should be primitive ptr | |||
| // cnode input 0 should be primitive ptr or funcgraph ptr | |||
| auto value_ptr = input(0)->cast<ValueNodePtr>(); | |||
| if (value_ptr == nullptr) { | |||
| MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; | |||
| @@ -84,11 +84,23 @@ std::string CNode::fullname_with_scope() { | |||
| return fullname_with_scope_; | |||
| } | |||
| PrimitivePtr prim = GetValue<PrimitivePtr>(input_value); | |||
| auto prim = input_value->cast<PrimitivePtr>(); | |||
| MS_EXCEPTION_IF_NULL(scope()); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| fullname_with_scope_ = | |||
| scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>()); | |||
| fullname_with_scope_ = scope()->name() + "/"; | |||
| if (prim != nullptr) { | |||
| fullname_with_scope_ += prim->name(); | |||
| } else { | |||
| auto func_graph = input_value->cast<FuncGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| if (fg_flag != nullptr) { | |||
| auto fg_name = GetValue<std::string>(fg_flag); | |||
| fullname_with_scope_ += "GraphKernel_" + fg_name; | |||
| } else { | |||
| fullname_with_scope_ += func_graph->ToString(); | |||
| } | |||
| } | |||
| fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>()); | |||
| } | |||
| return fullname_with_scope_; | |||
| @@ -77,9 +77,9 @@ class Bool : public Number { | |||
| TypeId generic_type_id() const override { return kNumberTypeBool; } | |||
| TypePtr DeepCopy() const override { return std::make_shared<Bool>(); } | |||
| std::string ToString() const override { return "Bool_"; } | |||
| std::string ToReprString() const override { return "bool_"; } | |||
| std::string DumpText() const override { return "Bool_"; } | |||
| std::string ToString() const override { return "Bool"; } | |||
| std::string ToReprString() const override { return "bool"; } | |||
| std::string DumpText() const override { return "Bool"; } | |||
| }; | |||
| // Int | |||
| @@ -34,7 +34,7 @@ namespace mindspore { | |||
| * Methods of Graph | |||
| */ | |||
| FuncGraph::FuncGraph() | |||
| : flags_(), | |||
| : attrs_(), | |||
| transforms_(), | |||
| parameter_default_value_(), | |||
| seen_(0), | |||
| @@ -95,13 +95,27 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { | |||
| return p; | |||
| } | |||
| bool FuncGraph::has_flag(const std::string &flag) { | |||
| if (flags_.count(flag)) { | |||
| return flags_[flag]; | |||
| bool FuncGraph::has_flag(const std::string &key) { | |||
| auto iter = attrs_.find(key); | |||
| if (iter != attrs_.cend()) { | |||
| if (iter->second->isa<BoolImm>()) { | |||
| return GetValue<bool>(iter->second); | |||
| } | |||
| MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function."; | |||
| } | |||
| return false; | |||
| } | |||
| bool FuncGraph::has_attr(const std::string &key) { | |||
| auto iter = attrs_.find(key); | |||
| return !(iter == attrs_.cend()); | |||
| } | |||
| ValuePtr FuncGraph::get_attr(const std::string &key) { | |||
| auto iter = attrs_.find(key); | |||
| return iter == attrs_.cend() ? nullptr : iter->second; | |||
| } | |||
| CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>()); | |||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | |||
| @@ -74,6 +74,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>; | |||
| const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | |||
| const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | |||
| const char FUNC_GRAPH_FLAG_CORE[] = "core"; | |||
| const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel"; | |||
| const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; | |||
| namespace abstract { | |||
| @@ -195,10 +196,19 @@ class FuncGraph : public FuncGraphBase { | |||
| void set_is_generate(bool generated) { is_generated_ = generated; } | |||
| bool is_generated() const { return is_generated_; } | |||
| bool has_flag(const std::string &flag); | |||
| std::unordered_map<std::string, bool> &flags() { return flags_; } | |||
| void set_flags(const std::unordered_map<std::string, bool> &flags) { flags_ = flags; } | |||
| void set_flags(const std::string &key, const bool value) { flags_[key] = value; } | |||
| std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; } | |||
| void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||
| for (auto &attr : attrs) { | |||
| attrs_[attr.first] = attr.second; | |||
| } | |||
| } | |||
| bool has_flag(const std::string &key); | |||
| void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); } | |||
| void erase_flag(const std::string &key) { (void)attrs_.erase(key); } | |||
| bool has_attr(const std::string &key); | |||
| ValuePtr get_attr(const std::string &key); | |||
| void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } | |||
| std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; } | |||
| void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) { | |||
| @@ -317,7 +327,7 @@ class FuncGraph : public FuncGraphBase { | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; } | |||
| std::unordered_map<std::string, bool> flags_; | |||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||
| std::unordered_map<std::string, FuncGraphTransform> transforms_; | |||
| // parameter default value | |||
| std::map<std::string, AnfNodePtr> parameter_default_value_; | |||
| @@ -90,6 +90,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||
| new_node->set_abstract(old_node->abstract()); | |||
| ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); | |||
| new_node->set_scope(scope); | |||
| new_node->set_kernel_info(old_node->kernel_info_ptr()); | |||
| repl_node_[old_node] = new_node; | |||
| nodes_.emplace_back(old_node, new_node); | |||
| TraceManager::EndTrace(); | |||
| @@ -211,7 +212,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons | |||
| MS_EXCEPTION_IF_NULL(target_func_graph); | |||
| TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); | |||
| *target_func_graph = std::make_shared<FuncGraph>(); | |||
| (*target_func_graph)->set_flags(func_graph->flags()); | |||
| (*target_func_graph)->set_attrs(func_graph->attrs()); | |||
| (*target_func_graph)->set_transforms(func_graph->transforms()); | |||
| (*target_func_graph)->set_has_vararg(func_graph->has_vararg()); | |||
| (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); | |||
| @@ -636,9 +637,14 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP | |||
| if (MsContext::GetInstance()->is_multi_graph_sink()) { | |||
| if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | |||
| new_func_graph->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| } | |||
| } | |||
| if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| } | |||
| return new_func_graph; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -399,8 +399,8 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { | |||
| depend_inputs.push_back(*iter); | |||
| } | |||
| } | |||
| set_flags(GRAPH_FLAG_HAS_EFFECT, false); | |||
| set_flags(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); | |||
| set_flag(GRAPH_FLAG_HAS_EFFECT, false); | |||
| set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); | |||
| if (!depend_inputs.empty()) { | |||
| SetEffectDepends(depend_inputs); | |||
| } | |||
| @@ -9,6 +9,10 @@ if (ENABLE_D) | |||
| file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "kernel_query.cc" | |||
| "kernel_fusion.cc" | |||
| "akg/ascend/*.cc" | |||
| "akg/akg_kernel_build.cc" | |||
| "akg/akg_kernel_attrs_process.cc" | |||
| "akg/akg_kernel_metadata.cc" | |||
| "tbe/*.cc" | |||
| "aicpu/*.cc" | |||
| "rts/*.cc" | |||
| @@ -33,7 +37,7 @@ if (ENABLE_GPU) | |||
| file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "gpu/*.cu" | |||
| "akg/gpu/*.cc" | |||
| "akg/akgkernelbuild.cc" | |||
| "akg/akg_kernel_build.cc" | |||
| "akg/akg_kernel_attrs_process.cc" | |||
| ) | |||
| @@ -24,7 +24,7 @@ | |||
| #include <map> | |||
| #include "device/kernel_runtime.h" | |||
| #include "kernel/aicpu/aicpu_kernel_mod.h" | |||
| #include "kernel/akg/akgkernelbuild.h" | |||
| #include "kernel/akg/akg_kernel_build.h" | |||
| #include "proto/tensor.pb.h" | |||
| #include "proto/tensor_shape.pb.h" | |||
| #include "proto/attr.pb.h" | |||
| @@ -79,6 +79,10 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) { | |||
| dst_type = "float32"; | |||
| } else if (output_type == kFloat16->type_id()) { | |||
| dst_type = "float16"; | |||
| } else if (output_type == kInt32->type_id()) { | |||
| dst_type = "int32"; | |||
| } else { | |||
| MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString(); | |||
| } | |||
| AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); | |||
| } | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/akg/akgkernelbuild.h" | |||
| #include "kernel/akg/akg_kernel_build.h" | |||
| #include <Python.h> | |||
| #include <sys/types.h> | |||
| #include <signal.h> | |||
| @@ -43,7 +43,9 @@ namespace kernel { | |||
| constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; | |||
| constexpr int32_t ARGS_SIZE = 1; | |||
| constexpr auto kCompileWithJsonFunc = "compilewithjson"; | |||
| // json key | |||
| constexpr auto kOpDesc = "op_desc"; | |||
| constexpr auto kInputDesc = "input_desc"; | |||
| constexpr auto kShape = "shape"; | |||
| constexpr auto kDataType = "data_type"; | |||
| @@ -51,13 +53,24 @@ constexpr auto kOutputDesc = "output_desc"; | |||
| constexpr auto kName = "name"; | |||
| constexpr auto kTensorName = "tensor_name"; | |||
| constexpr auto kValue = "value"; | |||
| constexpr auto KInpputNames = "input_names"; | |||
| constexpr auto KDynInputSizes = "dyn_input_sizes"; | |||
| constexpr auto KInputNames = "input_names"; | |||
| constexpr auto KInput = "input"; | |||
| constexpr auto KDtype = "dtype"; | |||
| int AkgKernelBuild::op_cnt_ = 0; | |||
| std::mutex AkgKernelBuild::op_cnt_mtx_; | |||
| namespace { | |||
| template <typename T> | |||
| std::string Vector2Str(const std::vector<T> &inputs) { | |||
| if (!inputs.empty()) { | |||
| std::ostringstream oss; | |||
| (void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator<T>(oss, ", ")); | |||
| oss << inputs.back(); | |||
| return oss.str(); | |||
| } | |||
| return ""; | |||
| } | |||
| } // namespace | |||
| std::string PyObjectToStr(PyObject *const PyObj) { | |||
| std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) { | |||
| char *pChar = nullptr; | |||
| std::string str_res; | |||
| if (PyObj == nullptr) { | |||
| @@ -76,6 +89,72 @@ std::string PyObjectToStr(PyObject *const PyObj) { | |||
| return str_res; | |||
| } | |||
| std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, | |||
| const std::pair<size_t, size_t> &position) { | |||
| if (node_json.count(tag) == 0) { | |||
| MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "]."; | |||
| return ""; | |||
| } | |||
| auto const &tag_desc = node_json[tag]; | |||
| nlohmann::json first_index; | |||
| if (tag == kOutputDesc) { | |||
| first_index = tag_desc; | |||
| } else if (!tag_desc.is_array() || tag_desc.size() <= position.first) { | |||
| MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "]."; | |||
| return ""; | |||
| } else { | |||
| first_index = tag_desc[position.first]; | |||
| } | |||
| if (!first_index.is_array() || first_index.size() <= position.second) { | |||
| MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "]."; | |||
| return ""; | |||
| } | |||
| auto const &second_index = first_index[position.second]; | |||
| if (second_index.count(kTensorName) == 0) { | |||
| MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "]."; | |||
| return ""; | |||
| } | |||
| return second_index[kTensorName]; | |||
| } | |||
| void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position, | |||
| nlohmann::json *const node_json) { | |||
| MS_EXCEPTION_IF_NULL(node_json); | |||
| if (node_json->count(tag) == 0) { | |||
| MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "]."; | |||
| return; | |||
| } | |||
| nlohmann::json *tag_desc = &((*node_json)[tag]); | |||
| nlohmann::json *first_index; | |||
| if (tag == kOutputDesc) { | |||
| first_index = tag_desc; | |||
| } else if (!tag_desc->is_array() || tag_desc->size() <= position.first) { | |||
| MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "]."; | |||
| return; | |||
| } else { | |||
| first_index = &((*tag_desc)[position.first]); | |||
| } | |||
| if (!first_index->is_array() || first_index->size() <= position.second) { | |||
| MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "]."; | |||
| return; | |||
| } | |||
| nlohmann::json *second_index = &((*first_index)[position.second]); | |||
| if (second_index->count(kTensorName) == 0) { | |||
| MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "]."; | |||
| return; | |||
| } | |||
| (*second_index)[kTensorName] = new_name; | |||
| return; | |||
| } | |||
| int AkgKernelBuild::op_cnt_ = 0; | |||
| std::mutex AkgKernelBuild::op_cnt_mtx_; | |||
| std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::string device; | |||
| @@ -187,10 +266,7 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j | |||
| for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { | |||
| // dtype : float16 | |||
| auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index); | |||
| TypePtr type_ptr = TypeIdToType(type_id); | |||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||
| std::string dtype = type_ptr->ToString(); | |||
| dtype = Dtype2String(dtype); | |||
| std::string dtype = TypeId2String(type_id); | |||
| if (dtype.empty()) { | |||
| MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. "; | |||
| return false; | |||
| @@ -198,13 +274,23 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j | |||
| nlohmann::json input_desc_json; | |||
| input_desc_json[kDataType] = dtype; | |||
| input_desc_json[kName] = op_input_name; | |||
| input_desc_json[kTensorName] = | |||
| op_input_name + "_" + std::to_string(real_input_index) + "_" + std::to_string(input_i); | |||
| input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); | |||
| input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index)); | |||
| auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); | |||
| if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { | |||
| MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2) | |||
| << "] as const tensor, shape: [" << Vector2Str(input_shape) | |||
| << "], value: " << input_desc_json[kValue]; | |||
| input_shape.clear(); | |||
| } | |||
| if (input_shape.empty()) { | |||
| input_shape.push_back(1); | |||
| } | |||
| input_desc_json[kShape] = input_shape; | |||
| input_list.emplace_back(input_desc_json); | |||
| real_input_index++; | |||
| } | |||
| inputs_json->emplace_back(input_list); | |||
| real_input_index++; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -220,10 +306,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann:: | |||
| for (size_t i = 0; i < output_tensor_num; i++) { | |||
| nlohmann::json output_json; | |||
| auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i); | |||
| TypePtr type_ptr = TypeIdToType(type_id); | |||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||
| std::string dtype = type_ptr->ToString(); | |||
| dtype = Dtype2String(dtype); | |||
| std::string dtype = TypeId2String(type_id); | |||
| if (dtype.empty()) { | |||
| MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. "; | |||
| return false; | |||
| @@ -232,7 +315,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann:: | |||
| std::string output_name = outputs[i]->name(); | |||
| output_json[kDataType] = dtype; | |||
| output_json[kName] = output_name; | |||
| output_json[kTensorName] = output_name + "_" + std::to_string(i); | |||
| output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc()); | |||
| output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i); | |||
| outputs_json->push_back(output_json); | |||
| } | |||
| @@ -358,15 +441,14 @@ bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const | |||
| MS_EXCEPTION_IF_NULL(op_info_ptr); | |||
| // get basic params from currentNodeOpDesc | |||
| (*node_json)["platform"] = "AKG"; | |||
| (*node_json)[kName] = op_name; | |||
| (*node_json)["fusion_type"] = AnfAlgo::GetFusionType(anf_node); | |||
| (*node_json)["impl_path"] = op_info_ptr->impl_path(); | |||
| (*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node); | |||
| (*node_json)["composite"] = false; | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| ValuePtr input_names_v = primitive->GetAttr(KInpputNames); | |||
| ValuePtr input_names_v = primitive->GetAttr(KInputNames); | |||
| if (input_names_v == nullptr) { | |||
| MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "]."; | |||
| return false; | |||
| @@ -465,12 +547,12 @@ KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNod | |||
| (void)alarm(0); | |||
| if (pRes == nullptr) { | |||
| MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" | |||
| << PyObjectToStr(pArg) << ")."; | |||
| << AkgKernelBuild::PyObjectToStr(pArg) << ")."; | |||
| return nullptr; | |||
| } | |||
| if (PyObject_IsTrue(pRes) != 1) { | |||
| MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" | |||
| << PyObjectToStr(pArg) << ")."; | |||
| << AkgKernelBuild::PyObjectToStr(pArg) << ")."; | |||
| return nullptr; | |||
| } | |||
| @@ -513,5 +595,29 @@ KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vecto | |||
| << "]"; | |||
| return kernel_pack; | |||
| } | |||
| size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (input_idx + 1 >= cnode->inputs().size()) { | |||
| MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" | |||
| << cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]"; | |||
| } | |||
| auto input_node = cnode->input(input_idx + 1); | |||
| if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) { | |||
| size_t index = input_tensor_idx_.size(); | |||
| input_tensor_idx_[input_node] = index; | |||
| } | |||
| return input_tensor_idx_[input_node]; | |||
| } | |||
| size_t AkgKernelBuild::GetOutputTensorIdxInc() { | |||
| size_t idx = output_tensor_idx_++; | |||
| return idx; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -32,29 +32,45 @@ namespace mindspore { | |||
| namespace kernel { | |||
| class AkgKernelBuild { | |||
| public: | |||
| AkgKernelBuild() = default; | |||
| AkgKernelBuild() { | |||
| input_tensor_idx_ = {}; | |||
| output_tensor_idx_ = 0; | |||
| } | |||
| ~AkgKernelBuild() = default; | |||
| KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size, | |||
| std::vector<size_t> *const output_size); | |||
| static std::string GetProcessor(const AnfNodePtr &anf_node); | |||
| static std::string PyObjectToStr(PyObject *const PyObj); | |||
| private: | |||
| protected: | |||
| bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json); | |||
| bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json); | |||
| bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, | |||
| const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json); | |||
| KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node); | |||
| int GetOpCntInc(); | |||
| size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx); | |||
| size_t GetOutputTensorIdxInc(); | |||
| bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, | |||
| nlohmann::json *const node_json); | |||
| KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node); | |||
| int GetOpCntInc(); | |||
| std::string GetProcessor(const AnfNodePtr &anf_node); | |||
| static int op_cnt_; | |||
| // lock for variable fusionOpCnt in singleton mode | |||
| static std::mutex op_cnt_mtx_; | |||
| std::string json_name_; | |||
| std::string json_info_; | |||
| std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_; | |||
| size_t output_tensor_idx_; | |||
| }; | |||
| bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size, | |||
| std::vector<size_t> *const output_size); | |||
| void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position, | |||
| nlohmann::json *const node_json); | |||
| std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, | |||
| const std::pair<size_t, size_t> &position); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/akg/akg_kernel_metadata.h" | |||
| #include <memory> | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "kernel/common_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void AkgMetadataInfo(const CNodePtr &kernel_node, | |||
| std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| for (size_t i = 0; i < support_devices.size(); i++) { | |||
| auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); | |||
| if (op_info_ptr == nullptr) { | |||
| continue; | |||
| } | |||
| if (!ParseMetadata(kernel_node, op_info_ptr, Processor(i), kernel_info_list)) { | |||
| MS_LOG(WARNING) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "] failed."; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "]."; | |||
| break; | |||
| } | |||
| } | |||
| if (kernel_info_list->empty()) { | |||
| MS_LOG(WARNING) << "Akg dose not has metadata of op[" << op_name << "]."; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "kernel/kernel_build_info.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ | |||
| @@ -0,0 +1,385 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/akg/ascend/akg_ascend_kernel_build.h" | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_set> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <Python.h> | |||
| #include "ir/dtype.h" | |||
| #include "ir/func_graph.h" | |||
| #include "kernel/kernel.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "kernel/tbe/tbe_utils.h" | |||
| #include "kernel/akg/ascend/akg_ascend_kernel_mod.h" | |||
| #include "kernel/akg/akg_kernel_attrs_process.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr int32_t PARALLEL_ARGS_SIZE = 3; | |||
| constexpr int32_t PROCESS_NUM = 16; | |||
| constexpr int32_t TIME_OUT = 300; | |||
| constexpr auto kOpDesc = "op_desc"; | |||
| constexpr auto kShape = "shape"; | |||
| constexpr auto kDataType = "data_type"; | |||
| constexpr auto kInputDesc = "input_desc"; | |||
| constexpr auto kOutputDesc = "output_desc"; | |||
| constexpr auto kTensorName = "tensor_name"; | |||
| constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel"; | |||
| constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler"; | |||
| bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||
| MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; | |||
| auto it = kAkgKernelAttrsProcessMap.find(op_name); | |||
| if (it != kAkgKernelAttrsProcessMap.end()) { | |||
| it->second(anf_node); | |||
| } | |||
| MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; | |||
| nlohmann::json node_json; | |||
| if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { | |||
| MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed."; | |||
| } | |||
| kernel_json_ = node_json.dump(); | |||
| if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) { | |||
| MS_LOG(ERROR) << "Cal mem size failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, | |||
| const std::vector<AnfNodePtr> &input_list, | |||
| const std::vector<AnfNodePtr> &output_list) { | |||
| if (anf_nodes.empty() || input_list.empty()) { | |||
| MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size() | |||
| << "]."; | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list [" | |||
| << input_list.size() << "]."; | |||
| std::map<AnfNodePtr, nlohmann::json> node_json_map; | |||
| for (auto const &anf_node : anf_nodes) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||
| if (!AnfAlgo::IsRealKernel(anf_node)) { | |||
| MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "]."; | |||
| return false; | |||
| } | |||
| auto it = kAkgKernelAttrsProcessMap.find(op_name); | |||
| if (it != kAkgKernelAttrsProcessMap.end()) { | |||
| it->second(anf_node); | |||
| } | |||
| nlohmann::json node_json; | |||
| if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { | |||
| MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed."; | |||
| return false; | |||
| } | |||
| // No need for composite op. | |||
| node_json.erase("id"); | |||
| node_json.erase("op"); | |||
| node_json.erase("composite"); | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| if (primitive->GetAttr("fusion") != nullptr) { | |||
| node_json["fusion"] = primitive->GetAttr("fusion")->ToString(); | |||
| } | |||
| node_json_map[anf_node] = node_json; | |||
| } | |||
| for (auto const &anf_node : anf_nodes) { | |||
| std::vector<int> dyn_input_sizes; | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { | |||
| dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes)); | |||
| } | |||
| bool is_dynamic_input = !dyn_input_sizes.empty(); | |||
| size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); | |||
| size_t real_input_index = 0; | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1; | |||
| for (size_t j = 0; j < input_tensor_num; ++j) { | |||
| auto tmp_input = GetKernelInput(anf_node, real_input_index); | |||
| std::string tensor_name = GetTensorName(node_json_map[anf_node], kInputDesc, std::make_pair(i, j)); | |||
| if (node_json_map.find(tmp_input.first) != node_json_map.end()) { | |||
| std::string new_tensor_name = | |||
| GetTensorName(node_json_map[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second)); | |||
| SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node])); | |||
| MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of [" | |||
| << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output [" | |||
| << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "]."; | |||
| } else { | |||
| MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of [" | |||
| << anf_node->fullname_with_scope() << "] is out input."; | |||
| } | |||
| real_input_index++; | |||
| } | |||
| } | |||
| } | |||
| nlohmann::json fused_node_json; | |||
| std::vector<nlohmann::json> node_json_desc; | |||
| std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc), | |||
| [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; }); | |||
| fused_node_json[kOpDesc] = node_json_desc; | |||
| nlohmann::json inputs_json; | |||
| auto input_index = GetInputIndex(anf_nodes, input_list); | |||
| for (size_t i = 0; i < input_index.size(); ++i) { | |||
| auto tmp_input = input_index[i]; | |||
| auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first); | |||
| std::string dtype = TypeId2String(type_id); | |||
| nlohmann::json input_desc_json; | |||
| input_desc_json[kTensorName] = GetTensorName(node_json_map[tmp_input.first], kInputDesc, tmp_input.second); | |||
| input_desc_json[kDataType] = dtype; | |||
| input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first); | |||
| inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json}); | |||
| } | |||
| fused_node_json[kInputDesc] = inputs_json; | |||
| nlohmann::json outputs_json; | |||
| auto output_index = GetOutputIndex(anf_nodes, input_list, output_list); | |||
| for (size_t i = 0; i < output_index.size(); ++i) { | |||
| auto tmp_output = output_index[i]; | |||
| bool found = false; | |||
| nlohmann::json output_desc_json; | |||
| for (size_t input_i = 0; input_i < input_list.size(); ++input_i) { | |||
| if (tmp_output.first == input_list[input_i]) { | |||
| output_desc_json = inputs_json[input_i][0]; | |||
| found = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!found) { | |||
| auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second); | |||
| std::string dtype = TypeId2String(type_id); | |||
| output_desc_json[kTensorName] = | |||
| GetTensorName(node_json_map[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second)); | |||
| output_desc_json[kDataType] = dtype; | |||
| auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second); | |||
| if (output_shape.empty()) { | |||
| output_shape.push_back(1); | |||
| } | |||
| output_desc_json[kShape] = output_shape; | |||
| } | |||
| outputs_json.emplace_back(output_desc_json); | |||
| } | |||
| fused_node_json[kOutputDesc] = outputs_json; | |||
| size_t hash_id = std::hash<std::string>()(fused_node_json.dump()); | |||
| json_name_ = "Fused_"; | |||
| auto fg = anf_nodes[0]->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| if (attr_val != nullptr) { | |||
| auto fg_attr = GetValue<std::string>(attr_val); | |||
| (void)json_name_.append(fg_attr).append("_"); | |||
| } | |||
| (void)json_name_.append(std::to_string(hash_id)); | |||
| fused_node_json["composite_graph"] = fg->ToString(); | |||
| fused_node_json["op"] = json_name_; | |||
| fused_node_json["platform"] = "AKG"; | |||
| fused_node_json["process"] = "aicore"; | |||
| fused_node_json["composite"] = true; | |||
| kernel_json_ = fused_node_json.dump(); | |||
| if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) { | |||
| MS_LOG(ERROR) << "Cal mem size failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void GenParallelCompileFuncArgs(const std::vector<std::string> &kernel_jsons, PyObject **p_args) { | |||
| MS_EXCEPTION_IF_NULL(p_args); | |||
| *p_args = PyTuple_New(PARALLEL_ARGS_SIZE); | |||
| PyObject *arg1 = PyList_New(kernel_jsons.size()); | |||
| for (int i = 0; i < PyList_Size(arg1); ++i) { | |||
| PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str())); | |||
| } | |||
| PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM); | |||
| PyObject *arg3 = Py_BuildValue("i", TIME_OUT); | |||
| (void)PyTuple_SetItem(*p_args, 0, arg1); | |||
| (void)PyTuple_SetItem(*p_args, 1, arg2); | |||
| (void)PyTuple_SetItem(*p_args, 2, arg3); | |||
| } | |||
| bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) { | |||
| // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess. | |||
| std::vector<std::string> jsons; | |||
| std::unordered_set<std::string> json_name_set; | |||
| std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> repeat_nodes; | |||
| for (const auto &[builder, anf_node] : build_args) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| auto json_name = builder.json_name(); | |||
| MS_LOG(DEBUG) << "Akg start compile op: " << json_name; | |||
| auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); | |||
| if (cached_kernel_pack != nullptr) { | |||
| MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope[" | |||
| << anf_node->fullname_with_scope() << "]."; | |||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack); | |||
| kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); | |||
| kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); | |||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||
| continue; | |||
| } | |||
| if (json_name_set.count(json_name) != 0) { | |||
| repeat_nodes.push_back({builder, anf_node}); | |||
| continue; | |||
| } | |||
| json_name_set.insert(json_name); | |||
| auto node_json = builder.kernel_json(); | |||
| kernel::SaveJsonInfo(json_name, node_json); | |||
| jsons.push_back(node_json); | |||
| } | |||
| // No nodes need to be compiled! | |||
| if (jsons.empty()) { | |||
| return true; | |||
| } | |||
| // Try to call python method to compile nodes parallely. | |||
| PyObject *p_module = nullptr; | |||
| PyObject *p_func = nullptr; | |||
| PyObject *p_arg = nullptr; | |||
| PyObject *p_res = nullptr; | |||
| p_module = PyImport_ImportModule(kMultiProcModule); | |||
| if (p_module == nullptr) { | |||
| MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "]."; | |||
| return false; | |||
| } | |||
| p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc); | |||
| GenParallelCompileFuncArgs(jsons, &p_arg); | |||
| MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size() | |||
| << " Akg kernels parallelly."; | |||
| p_res = PyEval_CallObject(p_func, p_arg); | |||
| if (p_res == nullptr) { | |||
| PyErr_Print(); | |||
| MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" | |||
| << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; | |||
| return false; | |||
| } | |||
| if (PyObject_IsTrue(p_res) != 1) { | |||
| PyErr_Print(); | |||
| MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" | |||
| << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; | |||
| return false; | |||
| } | |||
| // All unique done here, cache them and set kernel. | |||
| for (const auto &[builder, anf_node] : build_args) { | |||
| auto json_name = builder.json_name(); | |||
| auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); | |||
| if (new_kernel_pack == nullptr) { | |||
| MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope[" | |||
| << anf_node->fullname_with_scope() << "]."; | |||
| return false; | |||
| } | |||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack); | |||
| kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); | |||
| kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); | |||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||
| MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!"; | |||
| } | |||
| // Handle repeated nodes. | |||
| for (const auto &[builder, anf_node] : repeat_nodes) { | |||
| auto node_json = builder.kernel_json(); | |||
| auto json_name = builder.json_name(); | |||
| auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); | |||
| if (cached_kernel_pack == nullptr) return false; | |||
| MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope[" | |||
| << anf_node->fullname_with_scope() << "]."; | |||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack); | |||
| kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); | |||
| kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); | |||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||
| } | |||
| return true; | |||
| } | |||
| bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) { | |||
| std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> json_and_node; | |||
| for (const auto &anf_node : anf_nodes) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| AkgAscendKernelBuilder akg_cce_kernel_builder; | |||
| KernelPackPtr kernel_pack = nullptr; | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (AnfAlgo::IsGraphKernel(cnode)) { | |||
| auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| func_graph->set_manager(mng); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> node_list; | |||
| std::vector<AnfNodePtr> input_list; | |||
| std::vector<AnfNodePtr> output_list; | |||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||
| MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]"; | |||
| GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); | |||
| if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) { | |||
| MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "]."; | |||
| } | |||
| } else { | |||
| if (!akg_cce_kernel_builder.CollectJson(anf_node)) { | |||
| MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "]."; | |||
| } | |||
| } | |||
| json_and_node.push_back({akg_cce_kernel_builder, anf_node}); | |||
| } | |||
| if (json_and_node.empty()) { | |||
| MS_LOG(DEBUG) << "There is no kernel needed to be compiled."; | |||
| return true; | |||
| } | |||
| return AkgOpParallelBuild(json_and_node); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "kernel/kernel.h" | |||
| #include "kernel/akg/akg_kernel_build.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class AkgAscendKernelBuilder : public AkgKernelBuild { | |||
| public: | |||
| AkgAscendKernelBuilder() = default; | |||
| ~AkgAscendKernelBuilder() = default; | |||
| bool CollectJson(const AnfNodePtr &anf_node); | |||
| bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list, | |||
| const std::vector<AnfNodePtr> &output_list); | |||
| std::string json_name() const { return json_name_; } | |||
| std::string kernel_json() const { return kernel_json_; } | |||
| const std::vector<size_t> &input_size_list() const { return input_size_list_; } | |||
| const std::vector<size_t> &output_size_list() const { return output_size_list_; } | |||
| private: | |||
| std::string kernel_json_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| }; | |||
| bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ | |||
| @@ -0,0 +1,181 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/akg/ascend/akg_ascend_kernel_mod.h" | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "nlohmann/json.hpp" | |||
| #include "runtime/rt.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/convert_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| using std::fstream; | |||
| using std::map; | |||
| using std::mutex; | |||
| using std::string; | |||
| using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>; | |||
| using tbe::KernelManager; | |||
| constexpr uint32_t DEFAULT_BLOCK_DIM = 1; | |||
| /** | |||
| * @brief infotable contain func_stub\blockdim\kernel file buffer | |||
| */ | |||
| AkgKernelMod::AkgKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} | |||
| void AkgKernelMod::SetInputSizeList(const std::vector<size_t> &size_list) { input_size_list_ = size_list; } | |||
| void AkgKernelMod::SetOutputSizeList(const std::vector<size_t> &size_list) { output_size_list_ = size_list; } | |||
| void AkgKernelMod::SetWorkspaceSizeList(const std::vector<size_t> &size_list) { workspace_size_list_ = size_list; } | |||
| const std::vector<size_t> &AkgKernelMod::GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &AkgKernelMod::GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| void DumpData(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| const char *dump_data = getenv("MS_KERNEL_DUMP_DATA"); | |||
| if (dump_data) { | |||
| int idx = 0; | |||
| for (const auto &x : inputs) { | |||
| std::vector<char> buf(x->size); | |||
| if (RT_ERROR_NONE != rtMemcpy(buf.data(), buf.size(), reinterpret_cast<const void *>(x->addr), x->size, | |||
| RT_MEMCPY_DEVICE_TO_HOST)) { | |||
| MS_LOG(WARNING) << "Call runtime rtMemcpy error."; | |||
| return; | |||
| } | |||
| std::string file_name("input_"); | |||
| file_name += std::to_string(idx); | |||
| std::ofstream file(file_name, std::ios::binary); | |||
| if (file.is_open()) { | |||
| (void)file.write(buf.data(), SizeToLong(buf.size())); | |||
| file.close(); | |||
| idx++; | |||
| } else { | |||
| MS_LOG(ERROR) << "Open file failed."; | |||
| return; | |||
| } | |||
| } | |||
| idx = 0; | |||
| for (const auto &x : outputs) { | |||
| std::vector<char> buf(x->size); | |||
| if (RT_ERROR_NONE != rtMemcpy(buf.data(), buf.size(), reinterpret_cast<const void *>(x->addr), x->size, | |||
| RT_MEMCPY_DEVICE_TO_HOST)) { | |||
| MS_LOG(WARNING) << "Call runtime rtMemcpy error."; | |||
| return; | |||
| } | |||
| std::string file_name("output_"); | |||
| file_name += std::to_string(idx); | |||
| std::ofstream file(file_name, std::ios::binary); | |||
| if (file.is_open()) { | |||
| (void)file.write(buf.data(), SizeToLong(buf.size())); | |||
| file.close(); | |||
| idx++; | |||
| } else { | |||
| MS_LOG(ERROR) << "Open file failed."; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||
| if (stream_ptr == 0) { | |||
| MS_LOG(ERROR) << "stream_ptr should not be nullptr."; | |||
| return false; | |||
| } | |||
| if (kernel_pack_ == nullptr) { | |||
| MS_LOG(ERROR) << "kernel pack should not be nullptr."; | |||
| return false; | |||
| } | |||
| uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. | |||
| auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); | |||
| if (func_stub == 0) { | |||
| MS_LOG(ERROR) << "GenFuncStub failed."; | |||
| return false; | |||
| } | |||
| // pack all addresses into a vector. | |||
| std::vector<void *> runtime_args; | |||
| (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtime_args), | |||
| [](const AddressPtr &input) -> void * { return input->addr; }); | |||
| (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args), | |||
| [](const AddressPtr &output) -> void * { return output->addr; }); | |||
| rtL2Ctrl_t *l2ctrl = nullptr; | |||
| auto stream = reinterpret_cast<rtStream_t *>(stream_ptr); | |||
| if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast<void *>(func_stub), block_dim, runtime_args.data(), | |||
| SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) { | |||
| MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; | |||
| return false; | |||
| } | |||
| DumpData(inputs, outputs); | |||
| return true; | |||
| } | |||
| std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, uint32_t stream_id) { | |||
| if (kernel_pack_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "kernel pack should not be nullptr."; | |||
| } | |||
| std::vector<uint8_t> args; | |||
| uint32_t args_size = 0; | |||
| std::vector<uint8_t> sm_desc; | |||
| void *binary = nullptr; | |||
| uint32_t binary_size = 0; | |||
| std::vector<uint8_t> meta_data; | |||
| std::vector<void *> input_data_addrs; | |||
| std::vector<void *> output_data_addrs; | |||
| std::vector<void *> workspace_addrs; | |||
| // pack all addresses into a vector. | |||
| (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), | |||
| [](const AddressPtr &input) -> void * { return input->addr; }); | |||
| (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), | |||
| [](const AddressPtr &output) -> void * { return output->addr; }); | |||
| uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. | |||
| auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); | |||
| if (func_stub == 0) { | |||
| MS_LOG(EXCEPTION) << "GenFuncStub failed."; | |||
| } | |||
| std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); | |||
| MS_LOG(DEBUG) << "The block_dim is:" << block_dim; | |||
| TbeTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::TbeTaskInfo>( | |||
| stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, input_data_addrs, | |||
| output_data_addrs, workspace_addrs); | |||
| return {task_info_ptr}; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "kernel/ascend_kernel_mod.h" | |||
| #include "kernel/tbe/tbe_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class AkgKernelMod : public AscendKernelMod { | |||
| public: | |||
| explicit AkgKernelMod(const KernelPackPtr &kernel_pack); | |||
| ~AkgKernelMod() final {} | |||
| void SetInputSizeList(const std::vector<size_t> &size_list); | |||
| void SetOutputSizeList(const std::vector<size_t> &size_list); | |||
| void SetWorkspaceSizeList(const std::vector<size_t> &size_list); | |||
| const std::vector<size_t> &GetInputSizeList() const override; | |||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override; | |||
| std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uint32_t stream_id) override; | |||
| private: | |||
| KernelPackPtr kernel_pack_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| using AkgKernelModPtr = std::shared_ptr<AkgKernelMod>; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ | |||
| @@ -18,7 +18,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "kernel/kernel.h" | |||
| #include "kernel/akg/akgkernelbuild.h" | |||
| #include "kernel/akg/akg_kernel_build.h" | |||
| #include "kernel/akg/gpu/akg_gpu_kernel_mod.h" | |||
| #include "common/utils.h" | |||
| @@ -23,6 +23,11 @@ | |||
| #include "nlohmann/json.hpp" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "common/utils.h" | |||
| #include "ir/manager.h" | |||
| #include "ir/meta_tensor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "operator/ops.h" | |||
| #include "utils/graph_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -48,12 +53,6 @@ const std::map<TypeId, std::string> type_id_str_map = { | |||
| {TypeId::kNumberTypeBool, "bool"}, | |||
| }; | |||
| const std::map<std::string, std::string> DATATYPE_STRING_MAP{ | |||
| {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, | |||
| {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, | |||
| {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "bool"}, {"Float64", "double"}, | |||
| }; | |||
| const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = { | |||
| {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"}, | |||
| {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"}, | |||
| @@ -243,14 +242,6 @@ TypeId DtypeToTypeId(const std::string &dtypes) { | |||
| } | |||
| } | |||
| std::string Dtype2String(const std::string &dtypes) { | |||
| auto iter = DATATYPE_STRING_MAP.find(dtypes); | |||
| if (iter == DATATYPE_STRING_MAP.end()) { | |||
| MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes; | |||
| } | |||
| return iter->second; | |||
| } | |||
| std::string TypeId2String(TypeId type_id) { | |||
| auto iter = type_id_str_map.find(type_id); | |||
| if (iter == type_id_str_map.end()) { | |||
| @@ -361,7 +352,7 @@ bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou | |||
| output_num = 1; | |||
| } else { | |||
| if (output_idx < real_output_num) { | |||
| MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; | |||
| MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; | |||
| output_num = 1; | |||
| } | |||
| } | |||
| @@ -403,7 +394,7 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu | |||
| } | |||
| if (imply_type == kAKG) { | |||
| builder->SetKernelType(AUTO_DIFF_KERNEL); | |||
| builder->SetKernelType(AKG_KERNEL); | |||
| } else if (imply_type == kAICPU) { | |||
| builder->SetKernelType(AICPU_KERNEL); | |||
| } else { | |||
| @@ -634,5 +625,256 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie | |||
| } | |||
| unique_grad->indices_size_ = unique_indices_size + 1; | |||
| } | |||
| std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| if (index >= AnfAlgo::GetInputTensorNum(anf_node)) { | |||
| MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs."; | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| return AnfAlgo::VisitKernel(anf_node, 0); | |||
| } else { | |||
| return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0); | |||
| } | |||
| } | |||
| std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list, | |||
| const std::vector<AnfNodePtr> &input_list) { | |||
| std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index; | |||
| for (size_t i = 0; i < input_list.size(); ++i) { | |||
| auto const &input = input_list[i]; | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| bool found = false; | |||
| // using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>; | |||
| auto mng = input->func_graph()->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| const NodeUsersMap &users = mng->node_users(); | |||
| auto input_users = users.find(input); | |||
| if (input_users == users.end() || input_users->second.empty()) { | |||
| MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" | |||
| << input->func_graph()->ToString() << "] has no users."; | |||
| } | |||
| for (auto const &input_user : input_users->second) { | |||
| for (auto const &anf_node : node_list) { | |||
| if (anf_node != input_user.first) { | |||
| continue; | |||
| } | |||
| std::vector<int> dyn_input_sizes; | |||
| auto prim = AnfAlgo::GetCNodePrimitive(anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim->GetAttr(kAttrDynInputSizes) != nullptr) { | |||
| dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes)); | |||
| } | |||
| if (dyn_input_sizes.empty()) { | |||
| input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0))); | |||
| found = true; | |||
| break; | |||
| } else { | |||
| int used_as_idx = input_user.second - 1; | |||
| int accum_idx = 0; | |||
| size_t dyn_i = 0; | |||
| for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) { | |||
| accum_idx += dyn_input_sizes[dyn_i]; | |||
| if (used_as_idx < accum_idx) { | |||
| input_index.push_back(std::make_pair( | |||
| anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i]))))); | |||
| break; | |||
| } | |||
| } | |||
| if (dyn_i != dyn_input_sizes.size()) { | |||
| found = true; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (found) { | |||
| break; | |||
| } | |||
| } | |||
| if (!found) { | |||
| MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" | |||
| << input->func_graph()->ToString() << "] found no related kernel info."; | |||
| } | |||
| } | |||
| return input_index; | |||
| } | |||
| std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list, | |||
| const std::vector<AnfNodePtr> &input_list, | |||
| const std::vector<AnfNodePtr> &output_list) { | |||
| std::vector<std::pair<AnfNodePtr, size_t>> output_index; | |||
| for (size_t i = 0; i < output_list.size(); ++i) { | |||
| auto const &output = output_list[i]; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| bool found = false; | |||
| auto pree_node = AnfAlgo::VisitKernel(output, 0); | |||
| auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first); | |||
| if (pos != std::end(node_list)) { | |||
| output_index.push_back(pree_node); | |||
| continue; | |||
| } | |||
| auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first); | |||
| if (ret != std::end(input_list)) { | |||
| output_index.push_back(std::make_pair(pree_node.first, 0)); | |||
| found = true; | |||
| } | |||
| if (!found) { | |||
| MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of [" | |||
| << output->func_graph()->ToString() << "] found no related kernel info."; | |||
| } | |||
| } | |||
| return output_index; | |||
| } | |||
| void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) { | |||
| MS_EXCEPTION_IF_NULL(node_list); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return()); | |||
| for (auto const &node : node_lists) { | |||
| if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) { | |||
| node_list->push_back(node); | |||
| } | |||
| } | |||
| } | |||
| void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list, | |||
| std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) { | |||
| MS_EXCEPTION_IF_NULL(node_list); | |||
| MS_EXCEPTION_IF_NULL(input_list); | |||
| MS_EXCEPTION_IF_NULL(output_list); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| GetValidKernelNodes(func_graph, node_list); | |||
| auto parameters = func_graph->parameters(); | |||
| input_list->insert(input_list->begin(), parameters.begin(), parameters.end()); | |||
| auto func_output = func_graph->output(); | |||
| MS_EXCEPTION_IF_NULL(func_output); | |||
| if (func_output->isa<CNode>()) { | |||
| // multi output. | |||
| auto cnode = func_output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input0 = cnode->input(kAnfPrimitiveIndex); | |||
| MS_EXCEPTION_IF_NULL(input0); | |||
| if (IsPrimitive(input0, prim::kPrimMakeTuple)) { | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) { | |||
| auto input_node = cnode->input(input_idx); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first); | |||
| } | |||
| } else { | |||
| // single output. | |||
| output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); | |||
| } | |||
| } else { | |||
| // single output. | |||
| output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); | |||
| } | |||
| } | |||
| bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(node_json); | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (input_idx + 1 >= cnode->size()) { | |||
| MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" | |||
| << cnode->inputs().size() << "][" << cnode->DebugString() << "]"; | |||
| } | |||
| auto input_node = cnode->input(input_idx + 1); | |||
| if (!IsValueNode<tensor::Tensor>(input_node)) { | |||
| return false; | |||
| } | |||
| auto tensor = GetValueNode<tensor::TensorPtr>(input_node); | |||
| if (tensor == nullptr) { | |||
| return false; | |||
| } | |||
| auto type_id = tensor->data_type(); | |||
| auto *data = tensor->data_c(); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| if (tensor->DataDim() > 1 || tensor->DataSize() != 1) { | |||
| // not const tensor. | |||
| MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]"; | |||
| } | |||
| if (type_id == kFloat32->type_id()) { | |||
| float *val = static_cast<float *>(data); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| (*node_json)["value"] = val[0]; | |||
| MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "]."; | |||
| return true; | |||
| } else if (type_id == kFloat16->type_id()) { | |||
| float16 *val = static_cast<float16 *>(data); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| (*node_json)["value"] = static_cast<float>(val[0]); | |||
| MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "]."; | |||
| return true; | |||
| } else if (type_id == kInt32->type_id()) { | |||
| int *val = static_cast<int *>(data); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| (*node_json)["value"] = val[0]; | |||
| MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "]."; | |||
| return true; | |||
| } | |||
| MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]"; | |||
| return false; | |||
| } | |||
| void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node_list); | |||
| auto output = func_graph->output(); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (AnfAlgo::IsRealKernel(output)) { | |||
| // single output. | |||
| node_list->push_back(std::make_pair(output, 0)); | |||
| return; | |||
| } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { | |||
| auto output_cnode = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(output_cnode); | |||
| // multi output. | |||
| auto &inputs = output_cnode->inputs(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0); | |||
| node_list->push_back(in_with_idx); | |||
| } | |||
| return; | |||
| } | |||
| MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2) | |||
| << " of graph: " << func_graph->ToString(); | |||
| } | |||
| bool IsWeightBoundary(const AnfNodePtr &node) { | |||
| if (node->isa<ValueNode>()) { | |||
| return true; | |||
| } | |||
| if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -20,9 +20,12 @@ | |||
| #include <dirent.h> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <nlohmann/json.hpp> | |||
| #include "kernel/kernel.h" | |||
| #include "kernel/oplib/opinfo.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| @@ -79,13 +82,11 @@ bool CheckCache(const std::string &kernel_name); | |||
| KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); | |||
| KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); | |||
| TypeId DtypeToTypeId(const std::string &dtypes); | |||
| std::string Dtype2String(const std::string &dtypes); | |||
| std::string Dtype2ShortType(const std::string &dtypes); | |||
| std::string TypeId2String(TypeId type_id); | |||
| size_t GetDtypeNbyte(const std::string &dtypes); | |||
| bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor, | |||
| std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list); | |||
| bool IsAtomicNode(const CNodePtr &kernel_node); | |||
| void SaveJsonInfo(const std::string &json_name, const std::string &info); | |||
| std::string GetProcessor(const AnfNodePtr &anf_node); | |||
| bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b); | |||
| @@ -94,6 +95,18 @@ void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGr | |||
| size_t outer_dim); | |||
| void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, | |||
| size_t outer_dim); | |||
| std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index); | |||
| std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list, | |||
| const std::vector<AnfNodePtr> &input_list); | |||
| std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list, | |||
| const std::vector<AnfNodePtr> &input_list, | |||
| const std::vector<AnfNodePtr> &output_list); | |||
| void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list, | |||
| std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list); | |||
| void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list); | |||
| bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json); | |||
| void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list); | |||
| bool IsWeightBoundary(const AnfNodePtr &node); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -17,7 +17,7 @@ | |||
| #include <fstream> | |||
| #include "mindspore/ccsrc/kernel/kernel.h" | |||
| #include "kernel/kernel.h" | |||
| #include "kernel/akg/akgkernelbuild.h" | |||
| #include "kernel/akg/akg_kernel_build.h" | |||
| #include "nlohmann/json.hpp" | |||
| #include "securec/include/securec.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| @@ -27,7 +27,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AUTO_DIFF_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; | |||
| enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; | |||
| namespace kernel { | |||
| @@ -21,6 +21,7 @@ | |||
| #include "kernel/rts/rt_kernel_info.h" | |||
| #include "kernel/hccl/hccl_kernel_metadata.h" | |||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" | |||
| #include "kernel/akg/akg_kernel_metadata.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| @@ -59,10 +60,14 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, | |||
| } | |||
| } | |||
| } // namespace | |||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||
| void KernelQueryAll(const CNodePtr &kernel_node, | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| TbeMetadataInfo(kernel_node, kernel_info_list); | |||
| if (kernel_info_list->empty()) { | |||
| AicpuMetadataInfo(kernel_node, kernel_info_list); | |||
| if (!kernel_info_list->empty()) { | |||
| @@ -82,6 +87,28 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||
| if (kernel_info_list->empty()) { | |||
| MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!"; | |||
| } | |||
| } | |||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list, | |||
| KernelType kernel_type) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| switch (kernel_type) { | |||
| case KernelType::AKG_KERNEL: | |||
| AkgMetadataInfo(kernel_node, kernel_info_list); | |||
| break; | |||
| default: | |||
| KernelQueryAll(kernel_node, kernel_info_list); | |||
| break; | |||
| } | |||
| if (kernel_info_list->empty()) { | |||
| MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!"; | |||
| } | |||
| // check output | |||
| FilterInvalidKernelInfo(kernel_node, kernel_info_list); | |||
| } | |||
| @@ -25,7 +25,8 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | |||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list, | |||
| KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | |||
| void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | |||
| bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | |||
| bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | |||
| @@ -272,8 +272,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool is_gpu = (context->device_target() == kGPUDevice); | |||
| if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) || | |||
| (!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) { | |||
| if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) { | |||
| MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) | |||
| << ", current op num: " << op_info_.size(); | |||
| return nullptr; | |||
| @@ -347,7 +347,7 @@ static int TypeStrToDstType(const std::string &type_str) { | |||
| ret = 4; | |||
| } else if (type_str == "UInt64") { | |||
| ret = 10; | |||
| } else if (type_str == "Bool_") { | |||
| } else if (type_str == "Bool") { | |||
| ret = 12; | |||
| } else { | |||
| MS_LOG(INFO) << "Error type str is invailed: " << type_str; | |||
| @@ -51,7 +51,7 @@ const std::map<TypeId, std::string> type_id_str_maps = { | |||
| const std::map<std::string, std::string> type_str_maps = { | |||
| {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, | |||
| {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, | |||
| {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "int8"}, {"Float64", "float64"}, | |||
| {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool", "int8"}, {"Float64", "float64"}, | |||
| }; | |||
| const std::unordered_map<std::string, size_t> type_nbyte_maps = { | |||
| @@ -334,8 +334,8 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL | |||
| FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { | |||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->debug_info()->set_name("hyper_map"); | |||
| AnfNodePtr ptrFnArg = nullptr; | |||
| @@ -389,7 +389,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tu | |||
| MS_EXCEPTION_IF_NULL(a_tuple); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->debug_info()->set_name("tail"); | |||
| AnfNodePtr ptrTup = ret->add_parameter(); | |||
| @@ -409,7 +409,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list | |||
| MS_EXCEPTION_IF_NULL(a_list); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->debug_info()->set_name("tail"); | |||
| AnfNodePtr ptrList = ret->add_parameter(); | |||
| @@ -481,10 +481,10 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg | |||
| grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); | |||
| } | |||
| b->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| b->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| b->set_output(b->NewCNode(grads)); | |||
| fg->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); | |||
| (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); | |||
| return fg; | |||
| @@ -504,7 +504,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, | |||
| const std::vector<AnfNodePtr> ¶ms_list, const std::vector<AnfNodePtr> &args, | |||
| bool applyJ) { | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| auto weights_node = weights; | |||
| if (weights == nullptr && !args.empty()) { | |||
| @@ -625,7 +625,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp | |||
| std::ostringstream ss; | |||
| ss << "grad{" << nparam << "}"; | |||
| dfBuilder->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| dfBuilder->debug_info()->set_name(ss.str()); | |||
| ParameterPtr param_graph = dfBuilder->add_parameter(); | |||
| @@ -671,7 +671,7 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_lis | |||
| } | |||
| FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>(); | |||
| fg_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| fg_ptr->debug_info()->set_name("list_map"); | |||
| AnfNodePtr fn = fg_ptr->add_parameter(); | |||
| @@ -741,7 +741,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr | |||
| // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) | |||
| FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>(); | |||
| fgtrue_ptr->debug_info()->set_name("ftrue"); | |||
| fgtrue_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); | |||
| auto inputs = fgtrue_output_cnode->inputs(); | |||
| @@ -751,7 +751,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr | |||
| FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>(); | |||
| fgfalse_ptr->debug_info()->set_name("ffalse"); | |||
| fgfalse_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| fgfalse_ptr->set_output(resl); | |||
| AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), | |||
| @@ -808,7 +808,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li | |||
| } | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| AnfNodePtr p_tup_a = ret->add_parameter(); | |||
| AnfNodePtr p_tup_b = ret->add_parameter(); | |||
| @@ -912,7 +912,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ | |||
| GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| AnfNodePtr p_tuple = ret->add_parameter(); | |||
| (void)ret->add_parameter(); | |||
| @@ -941,7 +941,7 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar | |||
| AbstractBasePtrList branches = branches_abs->elements(); | |||
| if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) { | |||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | |||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| AnfNodePtr functions = ret_graph->add_parameter(); | |||
| auto index = ret_graph->add_parameter(); | |||
| @@ -304,7 +304,7 @@ FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrLi | |||
| } | |||
| auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters()); | |||
| func_graph->set_output(new_cnode); | |||
| func_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| return func_graph; | |||
| } | |||
| } // namespace prim | |||
| @@ -35,7 +35,7 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList & | |||
| MS_EXCEPTION_IF_NULL(arg0_list); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret->debug_info()->set_name("append"); | |||
| AnfNodePtr arg0_node = ret->add_parameter(); | |||
| @@ -51,8 +51,8 @@ AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &f | |||
| FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { | |||
| // Generate func for leaf nodes | |||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->debug_info()->set_name("map"); | |||
| AnfNodePtr ptrFnArg = nullptr; | |||
| if (fn_leaf_ == nullptr) { | |||
| @@ -237,8 +237,8 @@ AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, c | |||
| FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { | |||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->debug_info()->set_name("map"); | |||
| AnfNodePtr ptrFnArg = nullptr; | |||
| @@ -51,7 +51,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ | |||
| (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); | |||
| auto ret_graph = std::make_shared<FuncGraph>(); | |||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| AnfNodePtr fnNode = ret_graph->add_parameter(); | |||
| std::vector<AnfNodePtr> elems; | |||
| @@ -57,7 +57,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe | |||
| return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size()); | |||
| }); | |||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | |||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| for (size_t idx = 0; idx < args_spec_list.size(); idx++) { | |||
| (void)ret_graph->add_parameter(); | |||
| } | |||
| @@ -50,6 +50,12 @@ const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not"); | |||
| const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and"); | |||
| const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or"); | |||
| const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq"); | |||
| const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater"); | |||
| const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual"); | |||
| const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||
| const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||
| const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||
| const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual"); | |||
| // Type introspection | |||
| const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | |||
| @@ -166,17 +172,20 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | |||
| const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | |||
| const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | |||
| const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | |||
| const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||
| const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||
| const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||
| const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); | |||
| const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd"); | |||
| const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar"); | |||
| const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | |||
| const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | |||
| const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | |||
| const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | |||
| const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | |||
| const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | |||
| const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); | |||
| // NN | |||
| const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax"); | |||
| const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | |||
| const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | |||
| const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | |||
| @@ -253,6 +262,7 @@ const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||
| const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||
| const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||
| const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | |||
| const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | |||
| // Comm ops | |||
| const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| @@ -59,6 +59,12 @@ extern const PrimitivePtr kPrimBoolNot; | |||
| extern const PrimitivePtr kPrimBoolAnd; | |||
| extern const PrimitivePtr kPrimBoolOr; | |||
| extern const PrimitivePtr kPrimBoolEq; | |||
| extern const PrimitivePtr kPrimGreater; | |||
| extern const PrimitivePtr kPrimGreaterEqual; | |||
| extern const PrimitivePtr kPrimLess; | |||
| extern const PrimitivePtr kPrimLessEqual; | |||
| extern const PrimitivePtr kPrimEqual; | |||
| extern const PrimitivePtr kPrimNotEqual; | |||
| // Type introspection | |||
| extern const PrimitivePtr kPrimTypeOf; | |||
| @@ -157,6 +163,10 @@ extern const PrimitivePtr KPrimTransData; | |||
| extern const PrimitivePtr kPrimNMSWithMask; | |||
| extern const PrimitivePtr kPrimPad; | |||
| extern const PrimitivePtr kPrimArgMaxWithValue; | |||
| extern const PrimitivePtr kPrimRealDiv; | |||
| extern const PrimitivePtr kPrimSqrt; | |||
| extern const PrimitivePtr kPrimReciprocal; | |||
| extern const PrimitivePtr kPrimExpandDims; | |||
| // Maths | |||
| extern const PrimitivePtr kPrimTensorAdd; | |||
| @@ -183,9 +193,11 @@ extern const PrimitivePtr kPrimCumProd; | |||
| extern const PrimitivePtr kPrimSubscalar; | |||
| extern const PrimitivePtr kPrimInplaceAdd; | |||
| extern const PrimitivePtr kPrimInplaceSub; | |||
| extern const PrimitivePtr kPrimPow; | |||
| // NN | |||
| extern const PrimitivePtr kPrimFlatten; | |||
| extern const PrimitivePtr kPrimSoftmax; | |||
| extern const PrimitivePtr kPrimLogSoftmax; | |||
| extern const PrimitivePtr kPrimLogSoftmaxGrad; | |||
| extern const PrimitivePtr kPrimApplyCenteredRMSProp; | |||
| @@ -263,6 +275,7 @@ extern const PrimitivePtr kPrimInDict; | |||
| extern const PrimitivePtr kPrimNotInDict; | |||
| extern const PrimitivePtr kPrimMixedPrecisionCast; | |||
| extern const PrimitivePtr kPrimIsConsant; | |||
| extern const PrimitivePtr kPrimEquivFormat; | |||
| // Comm ops | |||
| extern const PrimitivePtr kPrimAllReduce; | |||
| @@ -45,10 +45,19 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas | |||
| : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { | |||
| TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info())); | |||
| k_graph_ = std::make_shared<FuncGraph>(); | |||
| if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); | |||
| } | |||
| TraceManager::EndTrace(); | |||
| TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info())); | |||
| tape_ = std::make_shared<FuncGraph>(); | |||
| // Add "_Grad" postfix | |||
| if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad"; | |||
| tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); | |||
| } | |||
| TraceManager::EndTrace(); | |||
| dout_ = tape_->add_parameter(); | |||
| @@ -368,7 +377,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { | |||
| (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg))); | |||
| (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal))); | |||
| // Reset defer_inline to enable successive inlining | |||
| primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); | |||
| primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); | |||
| auto functor = std::make_shared<DFunctor>(primal, resources_); | |||
| functor->Init(); | |||
| @@ -37,7 +37,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt | |||
| auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { | |||
| if (MsContext::GetInstance()->is_multi_graph_sink()) { | |||
| if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | |||
| f->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| } | |||
| } | |||
| }; | |||
| @@ -78,7 +78,10 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(cons); | |||
| auto dt = data->abstract(); | |||
| MS_EXCEPTION_IF_NULL(dt); | |||
| if (dt == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (!dt->isa<AbstractClass>()) { | |||
| MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << "."; | |||
| } | |||
| @@ -0,0 +1,157 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "optimizer/graph_kernel_reuse.h" | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include "./common.h" | |||
| #include "utils/graph_utils.h" | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| namespace opt { | |||
| bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) { | |||
| if (a->abstract() && b->abstract()) { | |||
| auto a_type = a->abstract()->GetTypeTrack(); | |||
| auto b_type = b->abstract()->GetTypeTrack(); | |||
| if (a_type != b_type) { | |||
| return false; | |||
| } | |||
| auto a_shape = a->abstract()->GetShapeTrack(); | |||
| auto b_shape = b->abstract()->GetShapeTrack(); | |||
| if (a_shape != nullptr && a_shape == b_shape) { | |||
| return true; | |||
| } | |||
| if (a_shape != nullptr && b_shape != nullptr && a_shape->isa<abstract::Shape>() && | |||
| b_shape->isa<abstract::Shape>()) { | |||
| return a_shape->cast<abstract::ShapePtr>()->shape() == b_shape->cast<abstract::ShapePtr>()->shape(); | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) { | |||
| bool changed = false; | |||
| auto fgs = manager->func_graphs(); | |||
| for (FuncGraphPtr &fg : fgs) { | |||
| if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| continue; | |||
| } | |||
| std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) { | |||
| if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) { | |||
| FuncGraphPtr new_fg = nullptr; | |||
| for (auto &cfg : graph_kernel_ops[key]) { | |||
| // If two graphs have different size then continue | |||
| auto fg_topos = TopoSort(fg->get_return()); | |||
| auto cfg_topos = TopoSort(cfg->get_return()); | |||
| if (fg_topos.size() != cfg_topos.size()) { | |||
| continue; | |||
| } | |||
| // Compare const tensor | |||
| bool has_same = true; | |||
| for (size_t i = 0; i < fg_topos.size(); ++i) { | |||
| if (IsValueNode<tensor::Tensor>(fg_topos[i])) { | |||
| if (!IsValueNode<tensor::Tensor>(cfg_topos[i])) { | |||
| has_same = false; | |||
| break; | |||
| } | |||
| auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]); | |||
| auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]); | |||
| if (!tensor1->ValueEqual(*tensor2)) { | |||
| has_same = false; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (!has_same) { | |||
| continue; | |||
| } | |||
| auto fg_input = fg->parameters(); | |||
| auto cfg_input = cfg->parameters(); | |||
| if (fg_input.size() != cfg_input.size()) { | |||
| continue; | |||
| } | |||
| // Compare input | |||
| for (size_t i = 0; i < fg_input.size(); ++i) { | |||
| if (!CompareNode(fg_input[i], cfg_input[i])) { | |||
| has_same = false; | |||
| break; | |||
| } | |||
| } | |||
| if (!has_same) { | |||
| continue; | |||
| } | |||
| // Compare output | |||
| if (!CompareNode(fg->output(), cfg->output())) { | |||
| continue; | |||
| } | |||
| // Find reusable fg | |||
| new_fg = cfg; | |||
| break; | |||
| } | |||
| if (new_fg != nullptr) { | |||
| // Replace current fg with existing fg | |||
| auto users = fg->func_graph_cnodes_index(); | |||
| for (auto &iter : users) { | |||
| auto cnode = iter.first->first->cast<CNodePtr>(); | |||
| auto new_input = cnode->inputs(); | |||
| auto main_graph = cnode->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(main_graph); | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) { | |||
| new_input[1] = NewValueNode(new_fg); | |||
| } else { | |||
| new_input[0] = NewValueNode(new_fg); | |||
| } | |||
| auto new_cnode = main_graph->NewCNode(new_input); | |||
| manager->Replace(iter.first->first, new_cnode); | |||
| changed = true; | |||
| } | |||
| } else { | |||
| // Add current fg to map | |||
| graph_kernel_ops[key].push_back(fg); | |||
| } | |||
| } | |||
| } else { | |||
| graph_kernel_ops[key] = {fg}; | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(root); | |||
| return DoReplace(manager); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H | |||
| #include <mindspore/ccsrc/session/anf_runtime_algorithm.h> | |||
| #include <unordered_map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| // Common subexpression elimination. | |||
| class GraphKernelReuse { | |||
| public: | |||
| GraphKernelReuse() : count(0) {} | |||
| virtual ~GraphKernelReuse() = default; | |||
| bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { | |||
| bool chg = ReuseGraphKernel(root, optimizer->resource()->manager()); | |||
| return chg; | |||
| } | |||
| bool CompareNode(const AnfNodePtr a, const AnfNodePtr other); | |||
| bool DoReplace(const FuncGraphManagerPtr manager); | |||
| bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager); | |||
| private: | |||
| std::unordered_map<std::string, std::vector<FuncGraphPtr>> graph_kernel_ops; | |||
| int count; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H | |||
| @@ -41,6 +41,8 @@ | |||
| #include "optimizer/irpass/incorporate_call.h" | |||
| #include "optimizer/irpass/grad_var_prepare.h" | |||
| #include "optimizer/irpass/param_replace.h" | |||
| #include "optimizer/irpass/mark_interface_fusion.h" | |||
| #include "optimizer/opt.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -48,7 +50,7 @@ namespace irpass { | |||
| OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | |||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | |||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); | |||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); | |||
| special_op_eliminate_ = | |||
| MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | |||
| {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, | |||
| @@ -90,7 +92,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| replace_refkey_by_param_ = | |||
| MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); | |||
| replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); | |||
| // Gradient transforms | |||
| expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); | |||
| minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); | |||
| @@ -115,6 +116,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| // Incorporation | |||
| incorporate_getitem_set_ = | |||
| MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | |||
| incorporate_getitem_from_param_ = | |||
| MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); | |||
| incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); | |||
| incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); | |||
| @@ -124,6 +127,17 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| // Convert | |||
| print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); | |||
| // Unused parameter eliminate | |||
| unused_parameter_eliminate_ = | |||
| MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); | |||
| unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); | |||
| // AddN eliminate | |||
| addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); | |||
| // Mark interface fusion | |||
| mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); | |||
| } | |||
| ResolveIRPassLib::ResolveIRPassLib() { | |||
| @@ -84,6 +84,7 @@ class OptimizeIRPassLib { | |||
| // Incorporation | |||
| SubstitutionPtr incorporate_getitem_set_; | |||
| SubstitutionPtr incorporate_getitem_from_param_; | |||
| SubstitutionPtr incorporate_call_; | |||
| SubstitutionPtr incorporate_call_switch_; | |||
| @@ -92,6 +93,16 @@ class OptimizeIRPassLib { | |||
| // Convert | |||
| SubstitutionPtr print_tuple_wrapper_; | |||
| // Unused parameter eliminate | |||
| SubstitutionPtr unused_parameter_eliminate_; | |||
| SubstitutionPtr unused_output_eliminate_; | |||
| // AddN eliminate | |||
| SubstitutionPtr addn_eliminate_; | |||
| // Fusion | |||
| SubstitutionPtr mark_interface_fusion_; | |||
| }; | |||
| // the collection of irpass for resolve action | |||
| @@ -145,6 +156,23 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) { | |||
| return IsValueNode<FuncGraph>(inp0); | |||
| } | |||
| // Check if CNode Input 0 is Func Graph of graph kernel. | |||
| inline bool IsCNodeGraphKernel(const AnfNodePtr &node) { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto inp0 = node->cast<CNodePtr>()->input(0); | |||
| if (IsValueNode<FuncGraph>(inp0)) { | |||
| auto fg = GetValueNode<FuncGraphPtr>(inp0); | |||
| if (fg == nullptr) { | |||
| return false; | |||
| } | |||
| return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||
| } | |||
| return false; | |||
| } | |||
| // Check if CNode Input 0 is CNode | |||
| inline bool IsCNodeDup(const AnfNodePtr &node) { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| @@ -83,6 +83,216 @@ class MultiplyByZeroOrOne : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}; | |||
| }; | |||
| // Support class used for checking if all values of a Tensor are equal `check_value_` | |||
| // Supported data types: double, float/float32, int/int32 | |||
| class CheckTensorConstant { | |||
| public: | |||
| explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} | |||
| ~CheckTensorConstant() = default; | |||
| bool IsTensorConstant(const ValuePtr &value) { | |||
| if (!value->isa<tensor::Tensor>()) { | |||
| return false; | |||
| } | |||
| auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||
| TypeId tensor_type = tensor_ptr->Dtype()->type_id(); | |||
| if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { | |||
| float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c()); | |||
| for (int i = 0; i < tensor_ptr->DataSize(); i++) { | |||
| if (fabs(data2[i] - check_value_) > FLT_EPSILON) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } else if (tensor_type == TypeId::kNumberTypeFloat64) { | |||
| double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c()); | |||
| for (int i = 0; i < tensor_ptr->DataSize(); i++) { | |||
| if (fabs(data2[i] - check_value_) > DBL_EPSILON) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { | |||
| int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c()); | |||
| for (int i = 0; i < tensor_ptr->DataSize(); i++) { | |||
| if (data2[i] != check_value_) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| // Un-support Data Types | |||
| return false; | |||
| } | |||
| bool IsTensorScalarConstant(const ValuePtr &value) { | |||
| if (!value->isa<tensor::Tensor>()) { | |||
| return false; | |||
| } | |||
| auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||
| if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { | |||
| return false; | |||
| } | |||
| return IsTensorConstant(value); | |||
| } | |||
| private: | |||
| int check_value_; | |||
| }; | |||
| // {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} | |||
| // {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} | |||
| class TensorMultiplyByZeroOrOne : public AnfVisitor { | |||
| public: | |||
| TensorMultiplyByZeroOrOne() : zero_(MakeValue(0)) {} | |||
| ~TensorMultiplyByZeroOrOne() override = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| AnfVisitor::Match(prim::kPrimMul)(node); | |||
| if (is_zero_) { | |||
| if (x_->func_graph() != node->func_graph()) { | |||
| return nullptr; | |||
| } | |||
| return NewTensorFilledWithData(node); | |||
| } | |||
| if (is_one_) { | |||
| return NewTensorFilledWithData(node, x_); | |||
| } | |||
| return nullptr; | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (is_zero_ || is_one_) { | |||
| x_ = node; | |||
| return; | |||
| } | |||
| if (IsParam(node)) { | |||
| x_ = node; | |||
| return; | |||
| } | |||
| if (IsCNode(node)) { | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { | |||
| is_zero_ = true; | |||
| return; | |||
| } | |||
| x_ = node; | |||
| return; | |||
| } | |||
| auto value = node->cast<ValueNodePtr>()->value(); | |||
| if (CheckTensorConstant(0).IsTensorConstant(value)) { | |||
| is_zero_ = true; | |||
| return; | |||
| } else if (CheckTensorConstant(1).IsTensorConstant(value)) { | |||
| is_one_ = true; | |||
| return; | |||
| } | |||
| x_ = node; | |||
| } | |||
| void Visit(const ValueNodePtr &vnode) override { | |||
| auto value = vnode->value(); | |||
| if (CheckTensorConstant(0).IsTensorConstant(value)) { | |||
| is_zero_ = true; | |||
| return; | |||
| } else if (CheckTensorConstant(1).IsTensorConstant(value)) { | |||
| is_one_ = true; | |||
| return; | |||
| } | |||
| x_ = vnode; | |||
| } | |||
| void Reset() { | |||
| x_ = nullptr; | |||
| is_one_ = false; | |||
| is_zero_ = false; | |||
| } | |||
| void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) { | |||
| if (!node->isa<ValueNode>()) { | |||
| return nullptr; | |||
| } | |||
| auto value = node->cast<ValueNodePtr>()->value(); | |||
| if (!value->isa<tensor::Tensor>()) { | |||
| return nullptr; | |||
| } | |||
| tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||
| return tensor_ptr->data_c(writable); | |||
| } | |||
| // Make a new tensor (when possible) with the same shape as of `node` | |||
| // If x is nullptr then fill new tensor will "0" | |||
| // If x is a tensor with empty shape then fill new tensor with the single value of x | |||
| // If x is a tensor with same shape as `node` then return x as result | |||
| AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) { | |||
| if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) { | |||
| return nullptr; | |||
| } | |||
| auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>(); | |||
| TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); | |||
| std::vector<int> tensor_shape = tensor_abstract->shape()->shape(); | |||
| auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); | |||
| size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true)); | |||
| if (x == nullptr) { | |||
| std::memset(data, 0, mem_size); | |||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||
| new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||
| return new_vnode; | |||
| } | |||
| // x is not nullptr | |||
| if (x->isa<CNode>()) { | |||
| if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) { | |||
| return nullptr; | |||
| } | |||
| auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>(); | |||
| std::vector<int> x_shape = x_abstract->shape()->shape(); | |||
| if (x_shape != tensor_shape) { | |||
| return nullptr; | |||
| } | |||
| return x; | |||
| } | |||
| if (!x->isa<ValueNode>()) { | |||
| return nullptr; | |||
| } | |||
| auto x_value = x->cast<ValueNodePtr>()->value(); | |||
| if (!x_value->isa<tensor::Tensor>()) { | |||
| return nullptr; | |||
| } | |||
| auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value); | |||
| if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { | |||
| return nullptr; | |||
| } | |||
| char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x)); | |||
| if (x_tensor_ptr->DataSize() == 1) { | |||
| for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { | |||
| memcpy(source_data, data + i * GetTypeByte(tensor_type_ptr), GetTypeByte(tensor_type_ptr)); | |||
| } | |||
| } else { | |||
| memcpy(source_data, data, mem_size); | |||
| } | |||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||
| new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||
| return new_vnode; | |||
| } | |||
| private: | |||
| bool is_zero_{false}, is_one_{false}; | |||
| ValuePtr zero_; | |||
| AnfNodePtr x_{nullptr}; | |||
| }; | |||
| // {prim::kPrimScalarAdd, X, 0} | |||
| // {prim::kPrimScalarAdd, 0, X} | |||
| class AddByZero : public AnfVisitor { | |||
| @@ -101,7 +311,8 @@ class AddByZero : public AnfVisitor { | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (node->isa<ValueNode>() && *GetValueNode(node) == *zero_) { | |||
| if (node->isa<ValueNode>() && | |||
| ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { | |||
| is_zero_ = true; | |||
| return; | |||
| } | |||
| @@ -139,10 +350,22 @@ class TensorAddByZero : public AnfVisitor { | |||
| is_zero_ = true; | |||
| return; | |||
| } | |||
| if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { | |||
| is_zero_ = true; | |||
| return; | |||
| } | |||
| x_ = node; | |||
| } | |||
| void Visit(const ValueNodePtr &vnode) override { | |||
| auto value = vnode->value(); | |||
| if (CheckTensorConstant(0).IsTensorConstant(value)) { | |||
| is_zero_ = true; | |||
| return; | |||
| } | |||
| } | |||
| void Reset() { | |||
| x_ = nullptr; | |||
| is_zero_ = false; | |||
| @@ -183,29 +406,143 @@ class OptUpdateZeroTensor : public AnfVisitor { | |||
| // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} | |||
| class ConstantDuplicateMul : public AnfVisitor { | |||
| public: | |||
| // Support function to multiply two constant tensors: partially support broadcasting shapes | |||
| template <typename T> | |||
| void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, | |||
| int out_data_size) { | |||
| T *data_1 = reinterpret_cast<T *>(in_data_1); | |||
| T *data_2 = reinterpret_cast<T *>(in_data_2); | |||
| T *data_out = new T[out_data_size]; | |||
| if (in_data_1_size == 1) { | |||
| for (int i = 0; i < out_data_size; i++) { | |||
| data_out[i] = data_1[0]; | |||
| } | |||
| } else { | |||
| for (int i = 0; i < out_data_size; i++) { | |||
| data_out[i] = data_1[i]; | |||
| } | |||
| } | |||
| if (in_data_2_size == 1) { | |||
| for (int i = 0; i < out_data_size; i++) { | |||
| data_out[i] *= data_2[0]; | |||
| } | |||
| } else { | |||
| for (int i = 0; i < out_data_size; i++) { | |||
| data_out[i] *= data_2[i]; | |||
| } | |||
| } | |||
| *out_data = reinterpret_cast<void *>(data_out); | |||
| return; | |||
| } | |||
| AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) { | |||
| if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) || | |||
| (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { | |||
| return nullptr; | |||
| } | |||
| auto value_1 = GetValueNode(vnode_1); | |||
| auto value_2 = GetValueNode(vnode_2); | |||
| if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) { | |||
| return nullptr; | |||
| } | |||
| auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1); | |||
| auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2); | |||
| auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>(); | |||
| auto tensor_2_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>(); | |||
| auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>(); | |||
| TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); | |||
| TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); | |||
| TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); | |||
| if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || | |||
| (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { | |||
| return nullptr; | |||
| } | |||
| std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape(); | |||
| int data_out_size = 1; | |||
| for (auto it : tensor_out_shape) { | |||
| data_out_size *= it; | |||
| } | |||
| if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { | |||
| return nullptr; | |||
| } | |||
| if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { | |||
| return nullptr; | |||
| } | |||
| void *data_out; | |||
| if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || | |||
| (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { | |||
| Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||
| tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||
| } else { | |||
| if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { | |||
| Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||
| tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||
| } else { | |||
| if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || | |||
| (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { | |||
| Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||
| tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||
| } else { | |||
| // Un-support data types | |||
| return nullptr; | |||
| } | |||
| } | |||
| } | |||
| auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape); | |||
| size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true)); | |||
| memcpy(data, data_out, mem_size); | |||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||
| new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||
| return new_vnode; | |||
| } | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| // {prim::kPrimMul, Tensor1, {...}} | |||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); | |||
| if (vnode_ == nullptr || cnode_ == nullptr) { | |||
| if (vnode_ == nullptr || c_p_node_ == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (!IsCNode(c_p_node_)) { | |||
| return nullptr; | |||
| } | |||
| auto tensor1 = vnode_; | |||
| auto mul = cnode_; | |||
| auto mul = c_p_node_->cast<CNodePtr>(); | |||
| Reset(); | |||
| // {prim::kPrimMul, Tensor2, {...}} | |||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); | |||
| if (vnode_ == nullptr || cnode_ == nullptr) { | |||
| if (vnode_ == nullptr || c_p_node_ == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto tensor2 = vnode_; | |||
| auto cnode = cnode_; | |||
| auto c_p_node = c_p_node_; | |||
| auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0)); | |||
| auto fg = node->func_graph(); | |||
| auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); | |||
| return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg); | |||
| auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); | |||
| if (new_mul_tensor == nullptr) { | |||
| auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); | |||
| return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); | |||
| } | |||
| return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| @@ -213,19 +550,40 @@ class ConstantDuplicateMul : public AnfVisitor { | |||
| vnode_ = node; | |||
| } | |||
| if (IsCNode(node)) { | |||
| cnode_ = node->cast<CNodePtr>(); | |||
| if (IsCNode(node) || IsParam(node)) { | |||
| c_p_node_ = node; | |||
| } | |||
| } | |||
| void Reset() { | |||
| vnode_ = nullptr; | |||
| cnode_ = nullptr; | |||
| c_p_node_ = nullptr; | |||
| } | |||
| private: | |||
| AnfNodePtr vnode_; | |||
| CNodePtr cnode_; | |||
| AnfNodePtr c_p_node_; | |||
| }; | |||
| class PowerOneEliminate : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| if (!IsValueNode<Scalar>(inputs[2])) { | |||
| return nullptr; | |||
| } | |||
| auto scalar = GetValueNode<ScalarPtr>(inputs[2]); | |||
| if (scalar->isa<FloatImm>() && GetValue<float>(scalar) == 1.0) { | |||
| return inputs[1]; | |||
| } else if (scalar->isa<IntergerImm>() && GetValue<int>(scalar) == 1) { | |||
| return inputs[1]; | |||
| } | |||
| return nullptr; | |||
| } | |||
| }; | |||
| // grad = AllReduce(grad) / worker_number | |||
| @@ -341,17 +699,21 @@ class ArithmeticSimplify { | |||
| public: | |||
| ArithmeticSimplify() | |||
| : multiply_by_zero_or_one_(), | |||
| tensor_multiply_by_zero_or_one_(), | |||
| add_by_zero_(), | |||
| tensor_add_by_zero_(), | |||
| identity_(prim::kPrimIdentity), | |||
| opt_update_zero_tensor_(), | |||
| constant_duplicate_mul_() { | |||
| constant_duplicate_mul_(), | |||
| power_one_() { | |||
| eliminaters_.emplace_back(multiply_by_zero_or_one_); | |||
| eliminaters_.emplace_back(tensor_multiply_by_zero_or_one_); | |||
| eliminaters_.emplace_back(add_by_zero_); | |||
| eliminaters_.emplace_back(tensor_add_by_zero_); | |||
| eliminaters_.emplace_back(identity_); | |||
| eliminaters_.emplace_back(opt_update_zero_tensor_); | |||
| eliminaters_.emplace_back(constant_duplicate_mul_); | |||
| eliminaters_.emplace_back(power_one_); | |||
| } | |||
| ~ArithmeticSimplify() = default; | |||
| @@ -368,11 +730,13 @@ class ArithmeticSimplify { | |||
| private: | |||
| MultiplyByZeroOrOne multiply_by_zero_or_one_; | |||
| TensorMultiplyByZeroOrOne tensor_multiply_by_zero_or_one_; | |||
| AddByZero add_by_zero_; | |||
| TensorAddByZero tensor_add_by_zero_; | |||
| PrimEliminater identity_; | |||
| OptUpdateZeroTensor opt_update_zero_tensor_; | |||
| ConstantDuplicateMul constant_duplicate_mul_; | |||
| PowerOneEliminate power_one_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| @@ -21,6 +21,7 @@ | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <unordered_set> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| @@ -28,7 +29,6 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| @@ -81,13 +81,32 @@ class IncorporateGetitem : public AnfVisitor { | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node); | |||
| if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (node->func_graph() != nullptr && idx_ >= 0 && fg_ != nullptr) { | |||
| auto new_fg = getitem_transform_(fg_, idx_); | |||
| (void)args_.insert(args_.begin(), NewValueNode(new_fg)); | |||
| return node->func_graph()->NewCNode(args_); | |||
| if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| // If graph kernel has muti output, do not split. | |||
| // some graph kernel output has EnvInstance node or DeadCode node should split. | |||
| auto output = fg_->output(); | |||
| if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { | |||
| auto output_cnode = output->cast<CNodePtr>(); | |||
| auto outputs = output_cnode->inputs(); | |||
| int real_output_cnt = 0; | |||
| for (size_t i = 1; i < outputs.size(); ++i) { | |||
| if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) { | |||
| real_output_cnt++; | |||
| if (real_output_cnt > 1) { | |||
| return nullptr; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return nullptr; | |||
| auto new_fg = getitem_transform_(fg_, idx_); | |||
| (void)args_.insert(args_.begin(), NewValueNode(new_fg)); | |||
| return node->func_graph()->NewCNode(args_); | |||
| } | |||
| void Visit(const CNodePtr &cnode) override { | |||
| @@ -115,6 +134,172 @@ class IncorporateGetitem : public AnfVisitor { | |||
| internal::GetitemTransform getitem_transform_; | |||
| }; | |||
| class IncorporateGetitemFromParam : public AnfVisitor { | |||
| public: | |||
| void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr ¶m, size_t input_idx) { | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &node_users = mng->node_users(); | |||
| if (node_users.find(param) == node_users.end() || node_users[param].empty()) { | |||
| args_.push_back(cnode->input(input_idx + 1)); | |||
| return; | |||
| } | |||
| for (auto &user : node_users[param]) { | |||
| if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { | |||
| // we do not process this case. | |||
| args_.push_back(cnode->input(input_idx + 1)); | |||
| return; | |||
| } | |||
| } | |||
| // update new args. | |||
| if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) { | |||
| // case 1 | |||
| replace_parameters_[input_idx] = true; | |||
| need_update_ = true; | |||
| auto make_tuple_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>(); | |||
| auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs(); | |||
| inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1; | |||
| args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end()); | |||
| } else { | |||
| // case 2 | |||
| auto prev_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>(); | |||
| auto prev_fg = GetValueNode<FuncGraphPtr>(prev_cnode->input(0)); | |||
| auto fg_output = prev_fg->output(); | |||
| if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) { | |||
| MS_LOG(ERROR) << "The return of: " << prev_fg->ToString() | |||
| << " should be a make tuple, but got: " << fg_output->DebugString(); | |||
| return; | |||
| } | |||
| replace_parameters_[input_idx] = true; | |||
| need_update_ = true; | |||
| auto make_tuple_cnode = fg_output->cast<CNodePtr>(); | |||
| inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1; | |||
| for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) { | |||
| auto new_getitem = | |||
| func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToInt(output_i))}); | |||
| auto aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(SizeToInt(output_i))); | |||
| new_getitem->input(2)->set_abstract(aptr); | |||
| new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract()); | |||
| args_.push_back(new_getitem); | |||
| } | |||
| } | |||
| } | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| Reset(); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto &inputs = cnode->inputs(); | |||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||
| if (fg == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto parameters = fg->parameters(); | |||
| if (parameters.size() != inputs.size() - 1) { | |||
| return nullptr; | |||
| } | |||
| replace_parameters_ = std::vector<bool>(parameters.size(), false); | |||
| inputs_num_ = std::vector<size_t>(parameters.size(), 1); | |||
| auto node_fg = node->func_graph(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) { | |||
| Process(node_fg, cnode, parameters[i - 1], i - 1); | |||
| } else { | |||
| args_.push_back(inputs[i]); | |||
| } | |||
| } | |||
| if (!need_update_) { | |||
| return nullptr; | |||
| } | |||
| FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp")); | |||
| mng->AddFuncGraph(new_fg); | |||
| auto node_users = mng->node_users(); | |||
| std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters(); | |||
| std::vector<AnfNodePtr> new_parameters; | |||
| size_t curr_input_idx{0}; | |||
| for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) { | |||
| if (!replace_parameters_[param_i]) { | |||
| if (parameters[param_i]->abstract() != nullptr) { | |||
| new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract()); | |||
| } | |||
| new_parameters.push_back(new_fg_parameters[param_i]); | |||
| curr_input_idx++; | |||
| continue; | |||
| } | |||
| // make a new parameter. | |||
| for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) { | |||
| auto new_param = std::make_shared<Parameter>(new_fg); | |||
| new_param->set_abstract(args_.at(curr_input_idx)->abstract()); | |||
| // update users of new parameter. | |||
| for (auto &user : node_users[new_fg_parameters[param_i]]) { | |||
| idx_ = -1; | |||
| AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode<Int32Imm>})(user.first); | |||
| if (idx_ == -1) { | |||
| MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString() | |||
| << " must be tuple getitem here, but got: " << user.first->DebugString(); | |||
| return nullptr; | |||
| } | |||
| if (input_i == IntToSize(idx_)) { | |||
| for (auto &sub_user : node_users[user.first]) { | |||
| auto sub_user_cnode = sub_user.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(sub_user_cnode); | |||
| sub_user_cnode->set_input(sub_user.second, new_param); | |||
| (void)mng->Replace(sub_user.first, sub_user_cnode); | |||
| } | |||
| } | |||
| } | |||
| // (void)mng->Replace(new_fg_parameters[param_i], new_param); | |||
| new_parameters.push_back(new_param); | |||
| curr_input_idx++; | |||
| } | |||
| } | |||
| mng->SetParameters(new_fg, new_parameters); | |||
| (void)args_.insert(args_.begin(), NewValueNode(new_fg)); | |||
| auto new_call = node_fg->NewCNode(args_); | |||
| new_call->set_abstract(node->abstract()); | |||
| return new_call; | |||
| } | |||
| void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int>(vnode->value()); } | |||
| void Visit(const CNodePtr &cnode) override {} | |||
| void Reset() { | |||
| replace_parameters_.clear(); | |||
| args_.clear(); | |||
| inputs_num_.clear(); | |||
| need_update_ = false; | |||
| idx_ = -1; | |||
| } | |||
| private: | |||
| std::vector<bool> replace_parameters_{}; | |||
| std::vector<AnfNodePtr> args_{}; | |||
| std::vector<size_t> inputs_num_{}; | |||
| bool need_update_{false}; | |||
| int idx_{-1}; | |||
| }; | |||
| // {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} | |||
| class IncorporateGetitemSwitch : public AnfVisitor { | |||
| public: | |||
| @@ -86,20 +86,10 @@ bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { | |||
| bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| auto &flags = node->func_graph()->flags(); | |||
| if (flags.find("inline_inside") != flags.end()) { | |||
| return flags["inline_inside"]; | |||
| } | |||
| return false; | |||
| return node->func_graph()->has_flag("inline_inside"); | |||
| } | |||
| bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { | |||
| auto &flags = fg->flags(); | |||
| if (flags.find("core") != flags.end()) { | |||
| return flags["core"]; | |||
| } | |||
| return false; | |||
| } | |||
| bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } | |||
| bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } | |||
| @@ -123,6 +113,13 @@ class InlinerBase : public AnfVisitor { | |||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { | |||
| return nullptr; | |||
| } | |||
| // Do not inline GraphKernel to Cell. | |||
| if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| // If the GraphKernel only contains a return node, we make it inlined. | |||
| if (fg->nodes().size() - fg->parameters().size() > 1) { | |||
| return nullptr; | |||
| } | |||
| } | |||
| Reset(); | |||
| bool is_match = false; | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H | |||
| #include <string> | |||
| #include <sstream> | |||
| #include <unordered_map> | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "operator/composite/composite.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| static int count = 0; | |||
| std::string GetFusionNumber() { | |||
| std::stringstream ss; | |||
| ss << std::setw(4) << std::setfill('0') << count; | |||
| std::string num = ss.str(); | |||
| ++count; | |||
| return "_" + num; | |||
| } | |||
| // Mark CNodes which can be merged in kernel build | |||
| class MarkInterfaceFusion : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto condition = cnode->input(1); | |||
| std::string cmp; | |||
| std::unordered_map<std::string, std::string> cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"}, | |||
| {"LessEqual", "LE"}, {"Less", "LT"}, | |||
| {"Equal", "EQ"}, {"NotEqual", "NE"}}; | |||
| if (IsPrimitiveCNode(condition)) { | |||
| auto prim_name = GetCNodeFuncName(condition->cast<CNodePtr>()); | |||
| if (cmp_list.count(prim_name) != 0) { | |||
| // Mark Select and compare node | |||
| cmp = cmp_list[prim_name]; | |||
| auto cnt = GetFusionNumber(); | |||
| AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition); | |||
| AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node); | |||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||
| if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) { | |||
| AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| void Visit(const AnfNodePtr &) override {} | |||
| private: | |||
| AnfNodePtr y_{nullptr}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H | |||
| @@ -19,6 +19,7 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| @@ -196,6 +197,131 @@ class AddNZeroFilter : public AnfVisitor { | |||
| std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{}; | |||
| bool has_zero_like_{false}; | |||
| }; | |||
| // {PrimAddN, {kPrimMakeTuple, Xs}} | |||
| // Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd. | |||
| // case0: AddN(inputs)(inputs size < 2) -> error | |||
| // case1: AddN(inputs)(all inputs is ValueNode) -> error | |||
| // case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor) | |||
| // case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input) | |||
| // -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) | |||
| class AddNEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| if (fg->recursive()) { | |||
| return nullptr; | |||
| } | |||
| auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg")); | |||
| mng->AddFuncGraph(new_fg); | |||
| need_update_ = false; | |||
| bool changed = false; | |||
| do { | |||
| changed = false; | |||
| changed |= Process(new_fg); | |||
| } while (changed); | |||
| if (!need_update_) { | |||
| return nullptr; | |||
| } else { | |||
| auto new_sx = inputs; | |||
| new_sx[0] = NewValueNode(new_fg); | |||
| return node->func_graph()->NewCNode(new_sx); | |||
| } | |||
| } | |||
| bool Process(const FuncGraphPtr &func_graph) { | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto nodes = TopoSort(func_graph->output()); | |||
| bool changed = false; | |||
| for (size_t i = 0; i < nodes.size(); ++i) { | |||
| auto node = nodes[i]; | |||
| if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &tuple_input = cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(tuple_input); | |||
| auto tuple_input_cnode = tuple_input->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_input_cnode); | |||
| auto &tuple_inputs = tuple_input_cnode->inputs(); | |||
| if (tuple_inputs.size() < 3) { | |||
| // case0: inputs size < 2, error | |||
| MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2); | |||
| } | |||
| int valuenode_num = | |||
| std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) { | |||
| if (IsValueNode<tensor::Tensor>(node)) { | |||
| return accumulator + 1; | |||
| } else { | |||
| return accumulator; | |||
| } | |||
| }); | |||
| if (IntToSize(valuenode_num) == tuple_inputs.size()) { | |||
| // case1: all inputs is ValueNode, error | |||
| MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2); | |||
| } | |||
| if (tuple_inputs.size() == 3) { | |||
| // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor) | |||
| MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2); | |||
| ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); | |||
| std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1], | |||
| tuple_inputs[2]}; | |||
| mng->Replace(node, func_graph->NewCNode(new_xs)); | |||
| changed = true; | |||
| continue; | |||
| } | |||
| auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(), | |||
| [](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); }); | |||
| if (first_valuenode == tuple_inputs.end()) { | |||
| // no ValueNode input found. | |||
| continue; | |||
| } else { | |||
| // case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) | |||
| std::vector<AnfNodePtr> make_tuple_new_xs{ | |||
| NewValueNode(prim::kPrimMakeTuple), | |||
| }; | |||
| std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(), | |||
| [&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) { | |||
| if (node != *first_valuenode) { | |||
| make_tuple_new_xs.push_back(node); | |||
| } | |||
| }); | |||
| ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations"); | |||
| auto new_addn = func_graph->NewCNode( | |||
| {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)}); | |||
| ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); | |||
| auto new_add = | |||
| func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn}); | |||
| (void)mng->Replace(node, new_add); | |||
| changed = true; | |||
| continue; | |||
| } | |||
| } | |||
| need_update_ |= changed; | |||
| return changed; | |||
| } | |||
| private: | |||
| bool need_update_{false}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -79,7 +79,7 @@ class ReduceOneEliminater : public AnfVisitor { | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (x_ == nullptr) { | |||
| if (!IsVNode(node) && x_ == nullptr) { | |||
| if (IsValueNode<tensor::Tensor>(node)) { | |||
| is_tensor_ = true; | |||
| } | |||
| @@ -23,6 +23,8 @@ | |||
| #include "optimizer/irpass.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "operator/composite/composite.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -36,6 +38,7 @@ class MakeRefEliminater : public AnfVisitor { | |||
| this->y_ = node; | |||
| return true; | |||
| }; | |||
| AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node); | |||
| return y_; | |||
| } | |||
| @@ -142,7 +142,7 @@ class ResetDeferInline : public AnfVisitor { | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| auto fg = GetValueNode<FuncGraphPtr>(node); | |||
| fg->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); | |||
| fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -22,6 +22,7 @@ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| @@ -41,7 +42,7 @@ class SpecializeTransform { | |||
| ~SpecializeTransform() = default; | |||
| FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args, | |||
| std::vector<PrimitivePtr> prim_args) { | |||
| std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) { | |||
| if (cache_.count(func_graph) == 0) { | |||
| cache_[func_graph] = {}; | |||
| } | |||
| @@ -69,6 +70,13 @@ class SpecializeTransform { | |||
| (void)mng->Replace(params[i], arg); | |||
| continue; | |||
| } | |||
| if (value_args[i] != nullptr) { | |||
| auto const_tensor = *value_args[i]; | |||
| auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor); | |||
| AnfNodePtr arg = NewValueNode(const_tensor_ptr); | |||
| (void)mng->Replace(params[i], arg); | |||
| continue; | |||
| } | |||
| new_params.push_back(params[i]); | |||
| } | |||
| @@ -108,6 +116,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { | |||
| std::vector<FuncGraphPtr> graph_args; | |||
| std::vector<PrimitivePtr> prim_args; | |||
| std::vector<tensor::TensorPtr> value_node_args; | |||
| std::vector<AnfNodePtr> new_xs; | |||
| bool hasVNode = false; | |||
| for (size_t i = 1; i < inputs.size(); i++) { | |||
| @@ -115,15 +124,24 @@ class SpecializeOnGraphArguments : public AnfVisitor { | |||
| auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]); | |||
| graph_args.push_back(fg_vnode); | |||
| prim_args.emplace_back(nullptr); | |||
| value_node_args.emplace_back(nullptr); | |||
| hasVNode = true; | |||
| } else if (IsValueNode<Primitive>(inputs[i])) { | |||
| auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]); | |||
| graph_args.emplace_back(nullptr); | |||
| prim_args.push_back(p_vnode); | |||
| value_node_args.emplace_back(nullptr); | |||
| hasVNode = true; | |||
| } else if (IsValueNode<tensor::Tensor>(inputs[i])) { | |||
| tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]); | |||
| graph_args.emplace_back(nullptr); | |||
| prim_args.emplace_back(nullptr); | |||
| value_node_args.emplace_back(t_vnode); | |||
| hasVNode = true; | |||
| } else { | |||
| graph_args.emplace_back(nullptr); | |||
| prim_args.emplace_back(nullptr); | |||
| value_node_args.emplace_back(nullptr); | |||
| new_xs.push_back(inputs[i]); | |||
| } | |||
| } | |||
| @@ -132,7 +150,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { | |||
| return nullptr; | |||
| } | |||
| auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args); | |||
| auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args); | |||
| (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); | |||
| return node->func_graph()->NewCNode(new_xs); | |||
| @@ -141,6 +159,146 @@ class SpecializeOnGraphArguments : public AnfVisitor { | |||
| private: | |||
| internal::SpecializeTransform specialize_transform_; | |||
| }; | |||
| // Eliminate unused parameters. | |||
| // {G, Xs} | |||
| class UnusedParasEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &inputs = cnode->inputs(); | |||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| std::vector<AnfNodePtr> parameters = fg->parameters(); | |||
| size_t size = parameters.size(); | |||
| if (size != inputs.size() - 1) { | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> new_xs; | |||
| std::vector<bool> keep_parameters; | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto &node_users = mng->node_users(); | |||
| bool has_unused_para = false; | |||
| for (size_t i = 0; i < size; ++i) { | |||
| auto iter = node_users.find(parameters[i]); | |||
| if (iter != node_users.end() && !iter->second.empty()) { | |||
| keep_parameters.push_back(true); | |||
| new_xs.push_back(inputs[i + 1]); | |||
| continue; | |||
| } | |||
| keep_parameters.push_back(false); | |||
| has_unused_para = true; | |||
| } | |||
| if (!has_unused_para) { | |||
| return nullptr; | |||
| } | |||
| FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp")); | |||
| mng->AddFuncGraph(new_fg); | |||
| std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters(); | |||
| std::vector<AnfNodePtr> new_parameters; | |||
| for (size_t i = 0; i < size; i++) { | |||
| if (keep_parameters[i]) { | |||
| if (parameters[i]->abstract() != nullptr) { | |||
| new_fg_parameters[i]->set_abstract(parameters[i]->abstract()); | |||
| } | |||
| new_parameters.push_back(new_fg_parameters[i]); | |||
| } | |||
| } | |||
| mng->SetParameters(new_fg, new_parameters); | |||
| (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); | |||
| return node->func_graph()->NewCNode(new_xs); | |||
| } | |||
| }; | |||
| // Eliminate unused outputs. | |||
| // {G, Xs} | |||
| class UnusedOutputEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| if (fg->recursive()) { | |||
| return nullptr; | |||
| } | |||
| auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg")); | |||
| mng->AddFuncGraph(new_fg); | |||
| auto new_fg_output = new_fg->output(); | |||
| if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) { | |||
| return nullptr; | |||
| } | |||
| auto output_cnode = new_fg_output->cast<CNodePtr>(); | |||
| auto &node_users = mng->node_users(); | |||
| if (node_users.count(node) == 0 || node_users[node].empty()) { | |||
| return nullptr; | |||
| } | |||
| std::unordered_set<int> used_output_idx; | |||
| std::vector<std::pair<AnfNodePtr, int>> all_users; | |||
| for (auto &node_user : node_users[node]) { | |||
| if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) { | |||
| return nullptr; | |||
| } | |||
| auto user_cnode = node_user.first->cast<CNodePtr>(); | |||
| size_t used_idx = GetValue<int>(user_cnode->input(2)->cast<ValueNodePtr>()->value()); | |||
| used_output_idx.insert(used_idx); | |||
| all_users.push_back(std::make_pair(node_user.first, used_idx)); | |||
| } | |||
| if (used_output_idx.size() >= output_cnode->inputs().size() - 1) { | |||
| // all output has users. | |||
| return nullptr; | |||
| } | |||
| if (used_output_idx.empty()) { | |||
| // we do not process this case. | |||
| return nullptr; | |||
| } else if (used_output_idx.size() == 1) { | |||
| // after eliminate, only one output left. | |||
| new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1)); | |||
| // update users. | |||
| for (auto &ret_user : all_users) { | |||
| (void)mng->Replace(ret_user.first, node); | |||
| } | |||
| } else { | |||
| // after eliminate, create new multi output. | |||
| std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)}; | |||
| std::unordered_map<int, int> new_idx_map; | |||
| for (auto idx : used_output_idx) { | |||
| new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1); | |||
| new_output_inputs.push_back(output_cnode->input(idx + 1)); | |||
| } | |||
| new_fg->set_output(new_fg->NewCNode(new_output_inputs)); | |||
| // update users. | |||
| for (auto &ret_user : all_users) { | |||
| auto ret_user_cnode = ret_user.first->cast<CNodePtr>(); | |||
| ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second])); | |||
| } | |||
| } | |||
| auto new_sx = inputs; | |||
| new_sx[0] = NewValueNode(new_fg); | |||
| return node->func_graph()->NewCNode(new_sx); | |||
| } | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -89,7 +89,7 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; | |||
| class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| public: | |||
| Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) | |||
| : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {} | |||
| : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {} | |||
| virtual ~Optimizer() = default; | |||
| void Init(const OptPassGroupMap &passes, bool run_only_once) { | |||
| @@ -132,6 +132,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| } | |||
| FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { | |||
| if (!is_enable_) { | |||
| return func_graph; | |||
| } | |||
| // Optimizer step counter; | |||
| int counter = -1; | |||
| bool changes = true; | |||
| @@ -171,7 +174,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| }; | |||
| use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); | |||
| if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) { | |||
| MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||
| MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||
| auto fg_name = | |||
| "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | |||
| func_graph->DumpFuncGraph(fg_name); | |||
| @@ -211,6 +214,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| void enable_watch_renormalize() { is_watch_renormalize_ = true; } | |||
| void disable_watch_renormalize() { is_watch_renormalize_ = false; } | |||
| bool is_watch_renormalize() { return is_watch_renormalize_; } | |||
| void set_enable(bool enable) { is_enable_ = enable; } | |||
| private: | |||
| const std::string name_; | |||
| @@ -220,6 +224,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| bool run_only_once_; | |||
| std::vector<AnfNodePtr> untyped_nodes_; | |||
| bool is_watch_renormalize_; | |||
| bool is_enable_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -64,7 +64,7 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti | |||
| DumpGraph(root, std::string(ALLREDUCE_FUSION_END)); | |||
| // allreduce fusion only run once | |||
| root->flags()[ALLREDUCE_FUSION_RUN_ONCE_ONLY] = true; | |||
| root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true); | |||
| res->results()[pipeline::kStepParallelGraph] = root; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| auto end_time = std::chrono::steady_clock::now(); | |||
| @@ -158,8 +158,8 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) || | |||
| func_graph->flags()[TRAINING]) { | |||
| if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) || | |||
| func_graph->has_flag(TRAINING)) { | |||
| return; | |||
| } | |||
| @@ -107,7 +107,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | |||
| time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); | |||
| MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us"; | |||
| root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY] = true; | |||
| root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true); | |||
| return changes; | |||
| } | |||
| @@ -2270,10 +2270,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) { | |||
| if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) { | |||
| if (HasStrategy(root)) { | |||
| MS_LOG(INFO) << "strategies ignored in " << parallel_mode | |||
| MS_LOG(INFO) << "Strategies ignored in " << parallel_mode | |||
| << ", set_strategy() only valid in [semi_]auto_parallel."; | |||
| } | |||
| root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; | |||
| root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); | |||
| } | |||
| return changes; | |||
| @@ -2330,11 +2330,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| DumpGraph(root, std::string(STEP_PARALLEL_END)); | |||
| // step parallel only run once | |||
| root->flags()[SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY] = true; | |||
| root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true); | |||
| res->results()[pipeline::kStepParallelGraph] = root; | |||
| // in auto parallel mode, no need to check if stategies set | |||
| root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; | |||
| root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); | |||
| (void)gettimeofday(&end_time, nullptr); | |||
| uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); | |||
| @@ -151,7 +151,10 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") | |||
| .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") | |||
| .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") | |||
| .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print."); | |||
| .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.") | |||
| .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, | |||
| "Set the GraphKernel switch to on or off.") | |||
| .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch."); | |||
| (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | |||
| .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | |||
| @@ -278,7 +278,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { | |||
| if (bprop_graph != nullptr) { | |||
| (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); | |||
| (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); | |||
| func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); | |||
| func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true); | |||
| } | |||
| } | |||
| *data = func_graph; | |||
| @@ -1448,15 +1448,23 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) { | |||
| } | |||
| py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG); | |||
| for (auto &item : flags) { | |||
| if (!py::isinstance<py::str>(item.first) || !py::isinstance<py::bool_>(item.second)) { | |||
| if (!py::isinstance<py::str>(item.first)) { | |||
| MS_LOG(ERROR) << "Type error in flags dict convert"; | |||
| return false; | |||
| } | |||
| auto name = py::cast<std::string>(item.first); | |||
| auto value = py::cast<bool>(item.second); | |||
| MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; | |||
| func_graph->set_flags(name, value); | |||
| if (py::isinstance<py::bool_>(item.second)) { | |||
| auto value = py::cast<bool>(item.second); | |||
| MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; | |||
| func_graph->set_flag(name, value); | |||
| } else if (py::isinstance<py::str>(item.second)) { | |||
| auto value = py::cast<std::string>(item.second); | |||
| MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; | |||
| func_graph->set_attr(name, MakeValue(value)); | |||
| } else { | |||
| MS_LOG(ERROR) << "Type error in flags/attrs dict convert"; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| @@ -223,8 +223,8 @@ class Parser { | |||
| FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse); | |||
| // In order to keep effect order in the sub-graphs which generated by control flow. | |||
| // We copy the flags from the top graph to the sub-graphs. | |||
| if (func_graph_ && !func_graph_->flags().empty()) { | |||
| block->func_graph()->set_flags(func_graph_->flags()); | |||
| if (func_graph_ && !func_graph_->attrs().empty()) { | |||
| block->func_graph()->set_attrs(func_graph_->attrs()); | |||
| } | |||
| func_block_list_.push_back(block); | |||
| return block; | |||
| @@ -25,12 +25,14 @@ | |||
| #include <functional> | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "debug/anf_ir_utils.h" | |||
| #include "pipeline/parse/parse_base.h" | |||
| #include "pipeline/parse/data_converter.h" | |||
| #include "pipeline/resource.h" | |||
| #include "pipeline/validator.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "optimizer/cse.h" | |||
| #include "optimizer/graph_kernel_reuse.h" | |||
| #include "optimizer/clean.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/control_depend.h" | |||
| @@ -38,6 +40,7 @@ | |||
| #include "parallel/step_auto_parallel.h" | |||
| #include "parallel/allreduce_fusion/step_allreduce_fusion.h" | |||
| #include "utils/any.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace pipeline { | |||
| @@ -162,6 +165,40 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig interface_fusion = opt::OptPassConfig({ | |||
| irpass.mark_interface_fusion_, | |||
| }); | |||
| OptPassGroupMap map({ | |||
| {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())}, | |||
| {"interface_fusion", interface_fusion}, | |||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | |||
| {"cse", opt::OptPassConfig(opt::CSE(false))}, | |||
| }); | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig elim_1 = opt::OptPassConfig({ | |||
| irpass.addn_eliminate_, | |||
| irpass.incorporate_getitem_from_param_, | |||
| }); | |||
| opt::OptPassConfig elim_2 = opt::OptPassConfig({ | |||
| irpass.unused_parameter_eliminate_, | |||
| irpass.unused_output_eliminate_, | |||
| }); | |||
| OptPassGroupMap map({ | |||
| {"elim_1", elim_1}, | |||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | |||
| {"elim_2", elim_2}, | |||
| }); | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}}); | |||
| } | |||
| OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true); | |||
| OptPassGroupMap map({ | |||
| @@ -191,8 +228,19 @@ void InitOpt(const ResourcePtr &res) { | |||
| opt::irpass::OptimizeIRPassLib irpass; | |||
| g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); | |||
| g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); | |||
| g_pass_opts["opt_graph_kernel_a"] = | |||
| Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); | |||
| g_pass_opts["opt_graph_kernel_b"] = | |||
| Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); | |||
| g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); | |||
| g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); | |||
| g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->enable_graph_kernel())) { | |||
| g_pass_opts["opt_graph_kernel_a"]->set_enable(false); | |||
| g_pass_opts["opt_graph_kernel_b"]->set_enable(false); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| @@ -224,9 +272,13 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) { | |||
| bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } | |||
| bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } | |||
| bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } | |||
| bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } | |||
| bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } | |||
| bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } | |||
| bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } | |||
| bool AddControlDependPass(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -270,8 +322,10 @@ bool InferenceOptPreparePass(const ResourcePtr &res) { | |||
| std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | |||
| {"opt_a", OptPassAGroup}, | |||
| {"opt_b", OptPassBGroup}, | |||
| {"add_control_depend", AddControlDependPass}, | |||
| {"cconv", CconvPass}}; | |||
| {"cconv", CconvPass}, | |||
| {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, | |||
| {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, | |||
| {"add_control_depend", AddControlDependPass}}; | |||
| std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | |||
| {"opt_a", OptPassAGroup}, | |||
| @@ -488,7 +488,7 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const | |||
| #ifdef ENABLE_INFER | |||
| // Now don't use the graph because the exec ge function don't take effect | |||
| MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph); | |||
| if (ENABLE_TRAIN != info.at(phase)->func_graph->flags()["training"]) { | |||
| if (ENABLE_TRAIN != info.at(phase)->func_graph->has_flag("training")) { | |||
| MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries"; | |||
| ConfigManager::GetInstance().ResetConfig(); | |||
| return py::none(); | |||
| @@ -165,7 +165,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa | |||
| MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list); | |||
| // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. | |||
| if (!(joined_args_spec_list == args_spec_list)) { | |||
| func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| } | |||
| return joined_args_spec_list; | |||
| } | |||
| @@ -178,7 +178,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa | |||
| // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. | |||
| if (!(joined_args_spec_list == args_spec_list)) { | |||
| trace_.push_back(joined_args_spec_list); | |||
| func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| } | |||
| MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); | |||
| return joined_args_spec_list; | |||
| @@ -479,7 +479,7 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { | |||
| if (undetermined_fgs) { | |||
| auto fg_parent = fg->parent(); | |||
| MS_EXCEPTION_IF_NULL(fg_parent); | |||
| fg_parent->set_flags(kFuncGraphFlagUndetermined, true); | |||
| fg_parent->set_flag(kFuncGraphFlagUndetermined, true); | |||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | |||
| } | |||
| } | |||
| @@ -16,6 +16,7 @@ | |||
| #include "pre_activate/ascend/ascend_backend_optimization.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <set> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "pre_activate/ascend/ir_fission/bn_split.h" | |||
| #include "pre_activate/ascend/ir_fission/bn_grad_split.h" | |||
| @@ -63,6 +64,9 @@ | |||
| #include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" | |||
| #include "pre_activate/pass/eliminate_redundant_op.h" | |||
| #include "pre_activate/pass/common_subexpression_elimination.h" | |||
| #include "pre_activate/pass/fuse_graph_kernel.h" | |||
| #include "pre_activate/pass/fuse_basic.h" | |||
| #include "pre_activate/pass/add_atomic_clean.h" | |||
| #include "pre_activate/ascend/format_type/merge_cast_to_op.h" | |||
| #include "pre_activate/ascend/format_type/check_consistency.h" | |||
| #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" | |||
| @@ -88,6 +92,8 @@ | |||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" | |||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" | |||
| #include "pre_activate/ascend/ir_fission/split_fission.h" | |||
| #include "pre_activate/ascend/format_type/modify_ops_attrs.h" | |||
| #include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "utils/config_manager.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| @@ -164,6 +170,19 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| } | |||
| void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| MS_EXCEPTION_IF_NULL(optimizer); | |||
| auto common_process = std::make_shared<PassManager>("graph_kernel_common_process"); | |||
| MS_EXCEPTION_IF_NULL(common_process); | |||
| common_process->AddPass(std::make_shared<ModifyOpAttrs>()); | |||
| common_process->AddPass(std::make_shared<RemoveNoUseReshapeOp>()); | |||
| optimizer->AddPassManager(common_process); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| } | |||
| void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| @@ -332,7 +351,94 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||
| std::string file_path = | |||
| save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||
| DumpIR(file_path, kernel_graph, true); | |||
| DumpIRProto(kernel_graph, "after_hwopt_" + std::to_string(kernel_graph->graph_id())); | |||
| DumpIRProto(kernel_graph, "after_hwopt"); | |||
| kernel_graph->DumpFuncGraph("hwopt_d_end"); | |||
| } | |||
| } | |||
| void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||
| bool is_before_kernel_select) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->enable_graph_kernel())) { | |||
| return; | |||
| } | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| if (save_graphs) { | |||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" + | |||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||
| ".ir"; | |||
| DumpIR(file_path, kernel_graph); | |||
| } | |||
| // Fuse graph kernels with basic ops | |||
| FuseGraphKernel(kernel_graph, is_before_kernel_select); | |||
| if (save_graphs) { | |||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" + | |||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||
| ".ir"; | |||
| DumpIR(file_path, kernel_graph, true); | |||
| } | |||
| } | |||
| void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||
| bool is_before_kernel_select) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->enable_graph_kernel())) { | |||
| return; | |||
| } | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| if (save_graphs) { | |||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" + | |||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||
| ".ir"; | |||
| DumpIR(file_path, kernel_graph, true); | |||
| } | |||
| // Fuse basic ops with basic ops | |||
| FuseBasic(kernel_graph, is_before_kernel_select); | |||
| if (save_graphs) { | |||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" + | |||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||
| ".ir"; | |||
| DumpIR(file_path, kernel_graph, true); | |||
| } | |||
| } | |||
| void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (!(context_ptr->enable_graph_kernel())) { | |||
| return; | |||
| } | |||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||
| if (save_graphs_path.empty()) { | |||
| save_graphs_path = "."; | |||
| } | |||
| if (save_graphs) { | |||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" + | |||
| std::to_string(kernel_graph->graph_id()) + ".ir"; | |||
| DumpIR(file_path, kernel_graph); | |||
| } | |||
| AddAtomicClean(kernel_graph); | |||
| if (save_graphs) { | |||
| std::string file_path = | |||
| save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||
| DumpIR(file_path, kernel_graph, true); | |||
| } | |||
| } | |||
| @@ -24,6 +24,12 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne | |||
| void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||
| bool is_before_kernel_select = false); | |||
| void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||
| bool is_before_kernel_select = false); | |||
| void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| } // namespace opt | |||
| @@ -22,6 +22,7 @@ | |||
| #include "utils/utils.h" | |||
| #include "device/kernel_info.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "operator/ops.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "session/kernel_graph.h" | |||
| @@ -229,7 +230,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr | |||
| if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) { | |||
| builder.SetKernelType(KernelType::TBE_KERNEL); | |||
| } else { | |||
| builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); | |||
| builder.SetKernelType(KernelType::AKG_KERNEL); | |||
| } | |||
| // if kernel info is null , it remarks this function is running ut | |||
| if (cast->kernel_info() == nullptr) { | |||
| @@ -284,22 +285,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||
| TypeId origin_type; | |||
| const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||
| TypeId origin_type(kTypeUnknown); | |||
| auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); | |||
| auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); | |||
| auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { | |||
| if (node->isa<ValueNode>()) { | |||
| return true; | |||
| } | |||
| if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) { | |||
| return true; | |||
| } | |||
| return false; | |||
| }; | |||
| auto real_input_node = kernel_with_index.first; | |||
| if (is_weight_boundary(real_input_node)) { | |||
| if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| // weight | |||
| origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); | |||
| origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index); | |||
| if (origin_type == kTypeUnknown) { | |||
| origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); | |||
| } | |||
| } else { | |||
| // feature map | |||
| origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||
| @@ -307,9 +303,13 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod | |||
| const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); | |||
| const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index); | |||
| const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); | |||
| if (origin_type != device_type) { | |||
| // In graph kernel, we check parameter, | |||
| // the eliminate pass will not eliminate this case, so we just do not insert the noused cast. | |||
| if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) { | |||
| new_inputs.push_back(cur_input); | |||
| } else if (origin_type != device_type) { | |||
| auto cast = | |||
| AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, origin_type); | |||
| AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type); | |||
| MS_EXCEPTION_IF_NULL(cast); | |||
| cast->set_scope(cnode->scope()); | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); | |||
| @@ -17,9 +17,12 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "utils/utils.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "common/utils.h" | |||
| #include "kernel/common_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -74,11 +77,21 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt | |||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { | |||
| return nullptr; | |||
| } | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { | |||
| if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) { | |||
| MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(node) << "[" | |||
| << node->DebugString() << "]"; | |||
| std::vector<AnfNodePtr> todos = {node}; | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| MS_EXCEPTION_IF_NULL(sub_graph); | |||
| kernel::GetValidKernelNodes(sub_graph, &todos); | |||
| } | |||
| for (auto &t : todos) { | |||
| CNodePtr cnode = t->cast<CNodePtr>(); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { | |||
| if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) { | |||
| MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "[" | |||
| << cnode->DebugString() << "]"; | |||
| } | |||
| } | |||
| } | |||
| return nullptr; | |||
| @@ -18,6 +18,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "device/kernel_info.h" | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| @@ -27,34 +28,45 @@ | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "utils/utils.h" | |||
| #include "kernel/common_utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<bool> &need_insert_cast) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| AbstractBasePtrList abstract_list; | |||
| make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) { | |||
| const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); | |||
| const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); | |||
| const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); | |||
| const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); | |||
| AnfNodePtr replace_node = nullptr; | |||
| const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); | |||
| const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); | |||
| auto idx = NewValueNode(SizeToInt(output_idx)); | |||
| MS_EXCEPTION_IF_NULL(idx); | |||
| auto imm = std::make_shared<Int32Imm>(output_idx); | |||
| idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm)); | |||
| auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); | |||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, getitem.get()); | |||
| AnfNodePtr replace_node = nullptr; | |||
| if (origin_type != device_type) { | |||
| replace_node = | |||
| AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, origin_type); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| replace_node->set_scope(cnode->scope()); | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); | |||
| AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get()); | |||
| if (need_insert_cast[output_idx]) { | |||
| const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); | |||
| TypeId origin_type(kTypeUnknown); | |||
| if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); | |||
| } | |||
| origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; | |||
| const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); | |||
| if (origin_type != device_type) { | |||
| replace_node = | |||
| AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, infer_type); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| replace_node->set_scope(cnode->scope()); | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); | |||
| } else { | |||
| replace_node = getitem; | |||
| } | |||
| } else { | |||
| replace_node = getitem; | |||
| } | |||
| @@ -65,9 +77,10 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| return make_tuple; | |||
| } | |||
| } // namespace | |||
| AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::vector<bool> &need_insert_cast) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { | |||
| @@ -76,14 +89,23 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c | |||
| MS_EXCEPTION_IF_NULL(cnode->Type()); | |||
| // Single output | |||
| if (!cnode->Type()->isa<Tuple>()) { | |||
| if (!need_insert_cast[0]) { | |||
| return cnode; | |||
| } | |||
| const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); | |||
| std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0); | |||
| const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, 0); | |||
| const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0); | |||
| TypeId origin_type(kTypeUnknown); | |||
| if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); | |||
| } | |||
| origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; | |||
| const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); | |||
| AnfNodePtr replace_node = cnode; | |||
| if (origin_type != device_type) { | |||
| replace_node = | |||
| AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, origin_type); | |||
| AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, infer_type); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| replace_node->set_scope(cnode->scope()); | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); | |||
| @@ -91,7 +113,57 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c | |||
| return replace_node; | |||
| } | |||
| // Multiple output | |||
| return InsertCastForMultipleOutput(func_graph, cnode); | |||
| return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast); | |||
| } | |||
| AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| // insert cast for ops in graph kernel. | |||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| MS_EXCEPTION_IF_NULL(sub_graph); | |||
| auto mng = sub_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| std::vector<AnfNodePtr> todo; | |||
| std::vector<std::pair<AnfNodePtr, size_t>> graph_rets; | |||
| kernel::GetValidKernelNodes(sub_graph, &todo); | |||
| kernel::GetGraphRealOutput(sub_graph, &graph_rets); | |||
| for (auto &t : todo) { | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t); | |||
| // process input | |||
| CNodePtr t_cnode = t->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(t_cnode); | |||
| auto t_new_node = InsertCastForInput(sub_graph, t_cnode); | |||
| AnfNodePtr t_new_node_1 = nullptr; | |||
| std::vector<bool> need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true); | |||
| // process output | |||
| auto iter = std::find_if(graph_rets.begin(), graph_rets.end(), | |||
| [&t](const std::pair<AnfNodePtr, size_t> &ret) { return ret.first == t; }); | |||
| if (iter != graph_rets.end()) { | |||
| auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t); | |||
| auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second); | |||
| auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin()); | |||
| if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) { | |||
| need_insert_cast[iter->second] = false; | |||
| } else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) { | |||
| need_insert_cast[iter->second] = false; | |||
| } | |||
| t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); | |||
| } else { | |||
| t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); | |||
| } | |||
| if (t_new_node_1 != nullptr && t_new_node_1 != t) { | |||
| (void)mng->Replace(t, t_new_node_1); | |||
| } | |||
| } | |||
| // insert cast for graph kernel. | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| // process input | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto new_node = InsertCastForInput(func_graph, cnode); | |||
| // process output | |||
| return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true)); | |||
| } | |||
| } // namespace | |||
| @@ -106,13 +178,27 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo | |||
| if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| return ProcessGraphKernelOp(func_graph, node); | |||
| } else { | |||
| // insert cast for single op. | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| // process input | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto new_node = InsertCastForInput(func_graph, cnode); | |||
| // process output | |||
| return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true)); | |||
| } | |||
| // insert cast for single op. | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| // process input | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto new_node = InsertCastForInput(func_graph, cnode); | |||
| // process output | |||
| return InsertCastForOutput(func_graph, new_node); | |||
| return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -133,6 +133,9 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co | |||
| return nullptr; | |||
| } | |||
| auto next_cnode = next_node->cast<CNodePtr>(); | |||
| if (AnfAlgo::IsGraphKernel(next_node)) { | |||
| return nullptr; | |||
| } | |||
| auto next_op_name = AnfAlgo::GetCNodeName(next_node); | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| kernel_query->Query(next_cnode, &kernel_info_list); | |||
| @@ -206,6 +209,9 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod | |||
| return nullptr; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(prior_op); | |||
| if (AnfAlgo::IsGraphKernel(prior_op)) { | |||
| return nullptr; | |||
| } | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | |||
| kernel_query->Query(prior_op, &kernel_info_list); | |||
| @@ -0,0 +1,99 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/format_type/modify_ops_attrs.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "utils/utils.h" | |||
| #include "pre_activate/common/helper.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) { | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | |||
| auto input_format = AnfAlgo::GetInputFormat(cnode, 0); | |||
| if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) { | |||
| return nullptr; | |||
| } | |||
| if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) { | |||
| return nullptr; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode); | |||
| return cnode; | |||
| } | |||
| AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) { | |||
| auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0); | |||
| if (input_shape.size() != 5) { | |||
| return nullptr; | |||
| } | |||
| if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) { | |||
| return nullptr; | |||
| } | |||
| auto multiples = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrMultiples); | |||
| if (multiples.size() == 4 && multiples[1] == 1) { | |||
| multiples.push_back(1); | |||
| AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode); | |||
| } | |||
| return cnode; | |||
| } | |||
| AnfNodePtr ModifyAttrs(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||
| if (op_name == prim::kPrimTile->name()) { | |||
| return ModifyTileOpAttrs(cnode); | |||
| } else if (op_name == prim::kPrimReduceSum->name()) { | |||
| // kPrimReduceMean | |||
| // kPrimReduceSum | |||
| // kPrimReduceAll | |||
| // kPrimReduceMax | |||
| // kPrimReduceMin | |||
| return ModifyReduceOpsAttrs(cnode); | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace | |||
| const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) { | |||
| return nullptr; | |||
| } | |||
| MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node); | |||
| auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto manager = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| std::vector<AnfNodePtr> todos; | |||
| kernel::GetValidKernelNodes(fg, &todos); | |||
| for (auto &t : todos) { | |||
| auto new_node = ModifyAttrs(t->cast<CNodePtr>()); | |||
| if (new_node != nullptr && new_node != t) { | |||
| (void)manager->Replace(t, new_node); | |||
| } | |||
| } | |||
| return node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H | |||
| #include "pre_activate/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ModifyOpAttrs : public PatternProcessPass { | |||
| public: | |||
| explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {} | |||
| ~ModifyOpAttrs() override = default; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "pre_activate/common/helper.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||
| if (op_name != prim::kPrimReshape->name()) { | |||
| return nullptr; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | |||
| auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0); | |||
| if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) { | |||
| return nullptr; | |||
| } | |||
| return cnode->input(1); | |||
| } | |||
| } // namespace | |||
| const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) { | |||
| return nullptr; | |||
| } | |||
| MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node); | |||
| auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto manager = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| std::vector<AnfNodePtr> todos; | |||
| kernel::GetValidKernelNodes(fg, &todos); | |||
| for (auto &t : todos) { | |||
| auto new_node = RemoveReshapeOp(t->cast<CNodePtr>()); | |||
| if (new_node != nullptr && new_node != t) { | |||
| (void)manager->Replace(t, new_node); | |||
| } | |||
| } | |||
| return node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H | |||
| #include "pre_activate/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class RemoveNoUseReshapeOp : public PatternProcessPass { | |||
| public: | |||
| explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {} | |||
| ~RemoveNoUseReshapeOp() override = default; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H | |||
| @@ -121,6 +121,9 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| return nullptr; | |||
| } | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<CNodePtr> cast_nodes; | |||
| @@ -102,9 +102,12 @@ bool UnVisited(const BaseRef &n) { | |||
| auto prim_py = value->cast<PrimitivePtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim_py); | |||
| return !prim_py->HasAttr(kAttrVisited); | |||
| } else { | |||
| return false; | |||
| } else if (IsValueNode<FuncGraph>(in)) { | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(in); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| return !func_graph->has_flag(kAttrVisited); | |||
| } | |||
| return false; | |||
| } | |||
| return false; | |||
| } | |||
| @@ -188,9 +191,12 @@ bool Visited(const BaseRef &n) { | |||
| auto prim_py = value->cast<PrimitivePtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim_py); | |||
| return prim_py->HasAttr(kAttrVisited); | |||
| } else { | |||
| return false; | |||
| } else if (IsValueNode<FuncGraph>(in)) { | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(in); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| return func_graph->has_flag(kAttrVisited); | |||
| } | |||
| return false; | |||
| } | |||
| return false; | |||
| } | |||