Merge pull request !2160 from xianwz/master_graph_kernel (tag: v0.5.0-beta)
@@ -13,3 +13,6 @@
[submodule "graphengine"]
    path = graphengine
    url = https://gitee.com/mindspore/graphengine.git
+[submodule "akg"]
+    path = akg
+    url = https://gitee.com/mindspore/akg.git
@@ -86,10 +86,14 @@ if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES)
    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
endif()
+if (ENABLE_AKG AND ENABLE_D)
+    add_subdirectory("${CMAKE_SOURCE_DIR}/akg")
+endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
add_subdirectory(mindspore/ccsrc)
if (ENABLE_TESTCASES)
    add_subdirectory(tests)
endif()
include(cmake/package.cmake)
@@ -0,0 +1 @@
+Subproject commit c460176523d039c8995f1d71089753725ebc0792
@@ -246,6 +246,9 @@ checkopts "$@"
echo "---------------- mindspore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine
+if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
+    git submodule update --init --recursive akg
+fi
build_exit()
{
@@ -308,7 +311,7 @@ build_mindspore()
    if [[ "X$USE_GLOG" = "Xon" ]]; then
        CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON"
    fi
-    if [[ "X$ENABLE_AKG" = "Xon" ]]; then
+    if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
        CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
    fi
    echo "${CMAKE_ARGS}"
@@ -236,6 +236,16 @@ if (ENABLE_GPU)
    endif ()
endif ()
+if (ENABLE_D AND ENABLE_AKG)
+    set (AKG_PATH ${CMAKE_SOURCE_DIR}/build/mindspore/akg)
+    install(
+        DIRECTORY
+            ${AKG_PATH}/akg
+        DESTINATION ${INSTALL_PY_DIR}/..
+        COMPONENT mindspore
+    )
+endif ()
if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset)
    install(
        DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset
@@ -0,0 +1,14 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
@@ -0,0 +1,35 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Providing akg compile with json"""
+import sys
+
+
+def run_compiler(op_json):
+    """
+    Run the AKG compiler to compile an op in a subprocess; if the
+    compilation fails, an exception will be raised.
+    Args:
+        op_json (str): json string of the op
+    Returns:
+        None
+    """
+    p = __import__("akg", globals(), locals(), ['ms'], 0)
+    func = getattr(p.ms, "compilewithjson")
+    res = func(op_json)
+    if not res:
+        raise ValueError("Compile error")
+
+
+if __name__ == "__main__":
+    run_compiler(sys.argv[1])
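For reference, a minimal sketch of how this single-op compiler is meant to be driven: the parent process launches it with the kernel-description JSON as the only argument and treats a non-zero exit code as a failed compilation. The payload and script path below are placeholders for illustration, not values taken from this change.

import subprocess
import sys

# Placeholder kernel description; in the real flow the JSON is produced by the
# graph-kernel build pass and handed to akg's compilewithjson entry point.
op_json = '{"op": "example_graph_kernel"}'

# One JSON string per invocation, mirroring how the parallel driver below calls it.
proc = subprocess.run([sys.executable, "compiler.py", op_json])
print("compile ok" if proc.returncode == 0 else "compile failed")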
@@ -0,0 +1,71 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Providing multi process compile with json"""
+import os
+import subprocess
+import sys
+from multiprocessing import Pool, cpu_count
+
+
+def _compile_akg_task(*json_strs):
+    """
+    Compile function called in a single process.
+    Parameters:
+        json_strs: list. List contains multiple kernel infos, suitable for the json compile api.
+    """
+    akg_compiler = os.path.join(os.path.split(
+        os.path.realpath(__file__))[0], "compiler.py")
+    for json_str in json_strs:
+        res = subprocess.run(
+            [sys.executable, akg_compiler, json_str], text=True)
+        if res.returncode != 0:
+            raise ValueError("Failed, args: {}!".format(json_str))
+
+
+def compile_akg_kernel_parallel(json_infos, process, waitime):
+    """
+    Compile kernels with multiple processes.
+    Parameters:
+        json_infos: list. List containing kernel info (task id and json str).
+        process: int. Number of processes.
+        waitime: int. Max time the function is blocked.
+    Returns:
+        True if all compilations succeed, False if some failed.
+    """
+    if not isinstance(json_infos, list):
+        raise ValueError("json_infos must be a list")
+    if not isinstance(process, int):
+        raise ValueError("process must be a num")
+    if not isinstance(waitime, int):
+        raise ValueError("waittime must be a num")
+
+    if process == 0 and json_infos:
+        process = 1
+    cpu_proc_num = cpu_count()
+    max_proc_num = 16
+    process = min([cpu_proc_num, max_proc_num, process])
+    args = [[] for _ in range(process)]
+    for p, info in enumerate(json_infos):
+        args[p % process].append(info)
+
+    with Pool(processes=process) as pool:
+        res = pool.starmap_async(_compile_akg_task, args)
+        res.get(timeout=waitime)
+    return True
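A rough usage sketch of the new parallel entry point; the import path and the JSON strings are assumptions for illustration only, since in the real flow both come from the Ascend graph-kernel build.

# Hypothetical import path; the module above is assumed importable as akg_process.
from akg_process import compile_akg_kernel_parallel

json_infos = ['{"op": "k0"}', '{"op": "k1"}', '{"op": "k2"}']  # placeholder kernel descriptions
# Round-robins the three kernels over at most min(cpu_count(), 16, 4) worker
# processes and waits up to 300 seconds for all of them to finish.
ok = compile_akg_kernel_parallel(json_infos, process=4, waitime=300)
print(ok)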
@@ -1,107 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""Providing multi process compile with json"""
-import json
-import math
-import os
-import subprocess
-import sys
-from multiprocessing import Pool
-
-
-def _compiletask(platform, *jsons):
-    """
-    compile func called in single process
-    Parameters:
-        platform: str. AKG platform or TBE platform
-        *jsons: str. json str contain kernel info, suitable for json compile
-        api
-    """
-    if platform == "AKG":
-        p = __import__("_akg", globals(), locals(), ['ms'], 0)
-        func = getattr(p.ms, "compilewithjson")
-        for json_item in jsons:
-            res = func(json_item)
-            if not res:
-                raise ValueError("Compile error")
-    if platform == "TBE":
-        tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
-        for json_item in jsons:
-            res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
-            if res.returncode != 0:
-                raise ValueError("Tbe compile error")
-
-
-def compilekernelparallel(jsons, process, waitime):
-    """
-    compile kernel use multi processes
-    Parameters:
-        jsons: list. json str list contain kernel info
-        process: int. processes num
-        waittime: int. max time the function blocked
-    """
-    if not isinstance(jsons, list):
-        raise ValueError("jsons must be a list")
-    if not isinstance(process, int):
-        raise ValueError("process must be a num")
-    if not isinstance(waitime, int):
-        raise ValueError("waittime must be a num")
-    jsons_akg = []
-    jsons_tbe = []
-    for json_ in jsons:
-        j = json.loads(json_)
-        if j["platform"] == "TBE":
-            jsons_tbe.append(json_)
-            continue
-        if j["platform"] == "AKG":
-            jsons_akg.append(json_)
-            continue
-        raise RuntimeError(
-            "not support this platform {0}".format(j["platform"]))
-    if jsons_akg:
-        process_akg = math.floor(len(jsons)/len(jsons_akg)*process)
-    else:
-        process_akg = 0
-    if process_akg == 0 and jsons_akg:
-        process_akg = 1
-    process_tbe = process-process_akg
-    if process_tbe == 0 and jsons_tbe:
-        process_tbe = 1
-        raise RuntimeWarning("we add a process for compile more operator")
-    args = [[] for _ in range(process_akg+process_tbe)]
-    args_lens = len(args)
-    for p in range(args_lens):
-        if p < process_tbe:
-            args[p].append("TBE")
-        else:
-            args[p].append("AKG")
-    jsons_tbe_lens = len(jsons_tbe)
-    for p in range(jsons_tbe_lens):
-        args[p % process_tbe].append(jsons_tbe[p])
-    jsons_akg_lens = len(jsons_akg)
-    for p in range(jsons_akg_lens):
-        args[process-p % process_akg-1].append(jsons_akg[p])
-    for p in range(args_lens):
-        args[p] = tuple(args[p])
-    with Pool(processes=process) as pool:
-        res = pool.starmap_async(_compiletask, args)
-        res.get(timeout=waitime)
-    return True
@@ -39,7 +39,7 @@ if(ENABLE_GPU)
        "device/gpu/*.cu"
        "kernel/gpu/*.cu"
        "kernel/akg/gpu/*.cc"
-        "kernel/akg/akgkernelbuild.cc"
+        "kernel/akg/akg_kernel_build.cc"
        "kernel/akg/akg_kernel_attrs_process.cc"
        )
@@ -428,6 +428,10 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
  auto temp_shape = shape;
  std::vector<size_t> device_shape;
  if (format == kOpFormat_FRAC_NZ) {
+    if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
+      // For [1] and [1024] shapes we can treat it as an NZ shape
+      return shape;
+    }
    if (shape.size() < 2) {
      MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
    } else {
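To make the new early-return concrete: assuming the usual cube size of 16, a rank-1 shape is now kept unchanged when its only dimension is 1 or a multiple of 16, and anything else still falls through to the existing rank check. A small illustrative transcription in Python (not the code that runs; the constant value is an assumption):

K_CUBE_SIZE = 16  # assumed value of kCubeSize


def keeps_1d_shape_as_frac_nz(shape):
    """Mirror of the added guard: True when a 1-D shape is returned unchanged."""
    return len(shape) == 1 and (shape[0] == 1 or shape[0] % K_CUBE_SIZE == 0)


print(keeps_1d_shape_as_frac_nz([1]))     # True
print(keeps_1d_shape_as_frac_nz([1024]))  # True, 1024 % 16 == 0
print(keeps_1d_shape_as_frac_nz([100]))   # False, handled by the original branch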
@@ -111,9 +111,15 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer)
  }
  buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl;
-  buffer << "#flags :" << std::endl;
-  for (const auto &flag : graph->flags()) {
-    buffer << flag.first << " : " << flag.second << std::endl;
+  buffer << "#attrs :" << std::endl;
+  for (const auto &attr : graph->attrs()) {
+    buffer << attr.first << " : ";
+    if (attr.second->isa<BoolImm>()) {
+      buffer << GetValue<bool>(attr.second);
+    } else if (attr.second->isa<StringImm>()) {
+      buffer << GetValue<std::string>(attr.second);
+    }
+    buffer << std::endl;
  }
}
@@ -417,10 +423,16 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
  fout << std::endl;
  for (const auto &sg : *sub_graphs) {
-    fout << "subgraph flag:" << std::endl;
+    fout << "subgraph attr:" << std::endl;
    MS_EXCEPTION_IF_NULL(sg.first);
-    for (const auto &flag : sg.first->flags()) {
-      fout << flag.first << " : " << flag.second << std::endl;
+    for (const auto &attr : sg.first->attrs()) {
+      fout << attr.first << " : ";
+      if (attr.second->isa<BoolImm>()) {
+        fout << GetValue<bool>(attr.second);
+      } else if (attr.second->isa<StringImm>()) {
+        fout << GetValue<std::string>(attr.second);
+      }
+      fout << std::endl;
    }
    fout << "subgraph @" << sg.first->ToString() << ".";
    fout << sg.first->debug_info()->get_id() << "(";
@@ -548,9 +548,15 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr
  for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
    cur_cnode_ptr = cnode_ptr_list[i];
    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+    ValuePtr value_ptr = nullptr;
    auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
-    MS_EXCEPTION_IF_NULL(primitive);
-    auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
+    if (primitive != nullptr) {
+      value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
+    } else {
+      auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cur_cnode_ptr);
+      MS_EXCEPTION_IF_NULL(func_graph);
+      value_ptr = func_graph->get_attr(kStreamNeedActivedFirst);
+    }
    if (value_ptr == nullptr) {
      continue;
    }
@@ -26,10 +26,12 @@
#include "kernel/kernel.h"
#include "kernel/tbe/tbe_kernel_build.h"
#include "kernel/tbe/tbe_kernel_parallel_build.h"
+#include "kernel/akg/ascend/akg_ascend_kernel_build.h"
#include "kernel/aicpu/aicpu_kernel_build.h"
#include "kernel/hccl/hccl_kernel_build.h"
#include "kernel/rts/rt_kernel_build.h"
#include "kernel/tbe/tbe_utils.h"
+#include "kernel/common_utils.h"
#include "operator/ops.h"
#include "session/anf_runtime_algorithm.h"
#include "./common.h"
@@ -91,6 +93,7 @@ static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph
static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  std::vector<AnfNodePtr> tbe_nodes;
+  std::vector<AnfNodePtr> akg_nodes;
  std::vector<AnfNodePtr> other_nodes;
  for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
    MS_EXCEPTION_IF_NULL(anf_node);
@@ -105,19 +108,26 @@ static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *ke
        }
        break;
      }
+      case KernelType::AKG_KERNEL: {
+        akg_nodes.push_back(anf_node);
+        break;
+      }
      default: {
        other_nodes.push_back(anf_node);
        break;
      }
    }
  }
-  bool ret = kernel::TbeOpParallelBuild(tbe_nodes);
+  bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes);
+  bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes);
+  auto bin_map = kernel::tbe::KernelMeta::GetInstance();
+  (void)bin_map->ReadIndex(kernel::kCceKernelMeta);
  for (const auto &anf_node : other_nodes) {
    kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node);
    MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
    AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
  }
-  return ret;
+  return tbe_ret && akg_ret;
}

static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) {
@@ -234,7 +244,7 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
  for (const auto &anf_node : kernel_graph->execution_order()) {
    std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
    if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
-        AnfAlgo::GetKernelType(anf_node) == KernelType::AUTO_DIFF_KERNEL) {
+        AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
      auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
      MS_EXCEPTION_IF_NULL(clear_zero_prim);
      auto new_value_node = NewValueNode(clear_zero_prim);
@@ -15,16 +15,27 @@
 */
#include "device/ascend/kernel_select_ascend.h"
#include <string>
#include <vector>
#include <memory>
#include <utility>
+#include <algorithm>
#include <map>
-#include "kernel/oplib/oplib.h"
-#include "kernel/kernel_query.h"
-#include "session/anf_runtime_algorithm.h"
-#include "utils/context/ms_context.h"
+#include <unordered_map>
+#include <unordered_set>
+#include "common/utils.h"
#include "debug/anf_ir_dump.h"
+#include "operator/ops.h"
+#include "ir/func_graph.h"
+#include "utils/context/ms_context.h"
+#include "session/anf_runtime_algorithm.h"
+#include "device/kernel_info.h"
+#include "kernel/common_utils.h"
+#include "kernel/kernel_query.h"
+#include "kernel/oplib/oplib.h"
+#include "kernel/kernel_build_info.h"

namespace mindspore {
namespace device {
@@ -121,12 +132,23 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
  }
  auto pri_match_format = GetPriorityMatchFormat(kernel_node);
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    auto input_anf_node = kernel_node->input(input_index + 1);
+    // we do not take ValueNode into consideration in graph kernel.
+    if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) {
+      if (input_anf_node->isa<ValueNode>() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
+        continue;
+      }
+    }
    auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
    if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
    }
-    if (kernel_build_info.GetInputDeviceType(input_index) ==
-        AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
+    // we match output fix precision first.
+    auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
+    if (prev_device_type == kTypeUnknown) {
+      prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
+    }
+    if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
      (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
    }
    if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
@@ -146,41 +168,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
  }
}

-void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
-    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
-    MS_EXCEPTION_IF_NULL(input_kernel_node);
-    auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
-    MS_EXCEPTION_IF_NULL(input_with_index.first);
-    auto real_input_node = input_with_index.first;
-    if (real_input_node->isa<CNode>()) {
-      continue;
-    }
-    std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
-      std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-    bool is_ref = false;
-    auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
-    if (op_info != nullptr) {
-      is_ref = op_info->is_ref();
-    }
-    auto ms_context = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(ms_context);
-    if (ms_context->execution_mode() == kPynativeMode &&
-        AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
-      continue;
-    }
-    // we set special device info of a input tensor.
-    if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
-      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
-      builder->SetOutputsFormat(output_format);
-      std::vector<TypeId> output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)};
-      builder->SetOutputsDeviceType(output_type);
-      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
-    }
-  }
-}
-
void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
  MS_EXCEPTION_IF_NULL(support_index);
  int index = kUnSupportMixedDataTypeIndex;
@@ -467,6 +454,51 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
  }
}  // namespace

+void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
+    MS_EXCEPTION_IF_NULL(input_with_index.first);
+    auto real_input_node = input_with_index.first;
+    if (real_input_node->isa<CNode>()) {
+      continue;
+    }
+    if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
+      continue;
+    }
+    auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+    if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
+        AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
+      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
+      builder->SetOutputsFormat(output_format);
+      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
+      builder->SetOutputsDeviceType(output_type);
+      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
+      continue;
+    }
+    // we set special device info of a input tensor.
+    bool is_ref = false;
+    auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
+    if (op_info != nullptr) {
+      is_ref = op_info->is_ref();
+    }
+    MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
+    if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
+        AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
+      continue;
+    }
+    if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
+      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
+      builder->SetOutputsFormat(output_format);
+      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
+      builder->SetOutputsDeviceType(output_type);
+      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
+    }
+  }
+}
+
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
                                        const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
@@ -498,11 +530,17 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
  return select_status;
}

-KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
+KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
  MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel::KernelQuery(kernel_node, &kernel_info_list);
+  if (AnfAlgo::IsGraphKernel(kernel_node)) {
+    auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
+    MS_EXCEPTION_IF_NULL(func_graph);
+    SelectGraphKernelInfo(kernel_node, func_graph);
+    return kStatusAllMatched;
+  }
+  kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
  auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
  // If aicore not find valid kernel info reloading aicpu kernel info list to find it
  if (select_status == kNoMatched) {
@@ -27,7 +27,10 @@ enum KernelSelectStatus {
  kStatusReducePrecision = 1,
  kStatusRaisePrecision = 2,
};
-KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node);
+KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node,
+                                    KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
+void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node);
+void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph);
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
@@ -0,0 +1,516 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "device/ascend/kernel_select_ascend.h"
+#include "session/anf_runtime_algorithm.h"
+#include "device/kernel_info.h"
+#include "ir/func_graph.h"
+#include "kernel/common_utils.h"
+#include "kernel/kernel_query.h"
+#include "kernel/kernel_build_info.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
+  MS_EXCEPTION_IF_NULL(primitive);
+
+  TypeId except_type = kTypeUnknown;
+  if (primitive->GetAttr(kAttrFixPrecision) != nullptr) {
+    auto strExceptDtype = GetValue<std::string>(primitive->GetAttr(kAttrFixPrecision));
+    if (strExceptDtype == "float16") {
+      except_type = kNumberTypeFloat16;
+    } else if (strExceptDtype == "float32") {
+      except_type = kNumberTypeFloat32;
+    } else {
+      MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got" << strExceptDtype;
+    }
+  }
+  return except_type;
+}
+
+void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t input_index = 0; input_index < input_num; ++input_index) {
+    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
+    if (!kernel::IsWeightBoundary(kernel_with_index.first)) {
+      continue;
+    }
+    // reset format and dtype.
+    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+    builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
+    builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
+    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get());
+  }
+}
+
+void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) {
+  for (size_t i = 0; i < node_list.size(); ++i) {
+    // select nodes in subgraph.
+    auto anf_node = node_list[i];
+    MS_EXCEPTION_IF_NULL(anf_node);
+    auto cnode = anf_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto fix_precision_type = GetPrimitivePrecision(cnode);
+    if (fix_precision_type != kTypeUnknown) {
+      std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
+      kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL);
+      for (size_t index = 0; index < kernel_info_list.size(); ++index)
+        // only match the first input
+        if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type &&
+            kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) &&
+            AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) {
+          auto selected_kernel_info_ptr = kernel_info_list[index];
+          ResetKernelBuildInfo(cnode);
+          AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get());
+          SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode);
+          break;
+        }
+    }
+  }
+}
+
+bool CanConvertDefaultShapeToNZ(const std::vector<size_t> &shape) {
+  for (size_t i = 1; i <= shape.size(); ++i) {
+    if (i > 2) {
+      break;
+    }
+    if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+std::vector<int> DefaultToFracNZAxis(const std::vector<size_t> &ori_shape, const std::vector<int> &axis) {
+  std::vector<int> frac_nz_axis = axis;
+  auto shape_len = ori_shape.size();
+  for (size_t i = 0; i < axis.size(); ++i) {
+    auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len;
+    if (axis_idx == shape_len - 1) {
+      frac_nz_axis[i] = axis_idx - 1;
+      frac_nz_axis.push_back(axis_idx + 2);
+    } else if (axis_idx == shape_len - 2) {
+      frac_nz_axis[i] = axis_idx + 1;
+      frac_nz_axis.push_back(axis_idx + 2);
+    } else {
+      frac_nz_axis[i] = axis_idx;
+    }
+  }
+  return frac_nz_axis;
+}
+
+std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape, const std::vector<int> &axis,
+                                          bool keep_dims) {
+  std::vector<size_t> result;
+  std::set<size_t> positive_idx;
+  for (const auto &a : axis) {
+    positive_idx.insert(a >= 0 ? a : ori_shape.size() + a);
+  }
+  for (size_t i = 0; i < ori_shape.size(); ++i) {
+    if (positive_idx.count(i) == 0) {
+      result.push_back(ori_shape[i]);
+    } else if (keep_dims) {
+      result.push_back(1);
+    }
+  }
+  return result;
+}
+
+void UpdateFracNZReduceOp(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
+  if (input_format == kOpFormat_FRAC_NZ) {
+    // Clone primitive to modify it
+    auto prim = GetCNodePrimitive(cnode);
+    auto new_prim = std::make_shared<Primitive>(*prim);
+    auto new_prim_node = NewValueNode(new_prim);
+    cnode->set_input(0, new_prim_node);
+
+    auto axis_value = new_prim->GetAttr(kAttrAxis);
+    std::vector<int> default_axis;
+    if (axis_value->isa<ValueList>()) {
+      auto value_list = dyn_cast<ValueList>(axis_value);
+      for (const auto &item : value_list->value()) {
+        if (item->isa<Int32Imm>()) {
+          default_axis.push_back(GetValue<int32_t>(item));
+        }
+      }
+    } else if (axis_value->isa<ValueTuple>()) {
+      auto value_tuple = dyn_cast<ValueTuple>(axis_value);
+      for (const auto &item : value_tuple->value()) {
+        if (item->isa<Int32Imm>()) {
+          default_axis.push_back(GetValue<int32_t>(item));
+        }
+      }
+    } else {
+      MS_LOG(ERROR) << "Axis attr type is not correct!";
+    }
+    auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
+    std::vector<int> frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis);
+    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<std::vector<int>>(frac_nz_axis), cnode);
+    auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
+    if (output_shape.size() == 1) {
+      AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue<bool>(true), cnode);
+    }
+  }
+}
+
+void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  MS_EXCEPTION_IF_NULL(default_format);
+  MS_EXCEPTION_IF_NULL(use_same_format);
+  std::unordered_map<std::string, size_t> all_input_formats;
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t i = 0; i < input_num; ++i) {
+    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    if (!input_kernel_node->isa<Parameter>()) {
+      auto pre_format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i);
+      ++all_input_formats[pre_format];
+      continue;
+    }
+    auto para = input_kernel_node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(para);
+    if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
+      auto pre_format = AnfAlgo::GetOutputFormat(para, 0);
+      ++all_input_formats[pre_format];
+      continue;
+    }
+    *use_same_format = false;
+  }
+
+  if (all_input_formats.empty()) {
+    // all inputs are parameter.
+    *default_format = kOpFormat_NC1HWC0;
+  } else {
+    std::vector<std::pair<std::string, size_t>> pairs;
+    for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
+      pairs.push_back(std::make_pair(iter->first, iter->second));
+    }
+    auto cmp_func = [](const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
+      if (a.second != b.second) {
+        return a.second > b.second;
+      } else if (a.first == kOpFormat_DEFAULT) {
+        return a.second + 1 > b.second;
+      } else if (b.first == kOpFormat_DEFAULT) {
+        return a.second > b.second + 1;
+      }
+      return a.second > b.second;
+    };
+    std::sort(pairs.begin(), pairs.end(), cmp_func);
+    *default_format = pairs.begin()->first;
+  }
+
+  for (size_t i = 0; i < input_num; ++i) {
+    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    if (!input_kernel_node->isa<Parameter>() ||
+        AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) {
+      continue;
+    }
+    auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0);
+    if (weight_infer_shape.size() < 2 && *default_format == kOpFormat_FRAC_NZ) {
+      *default_format = kOpFormat_DEFAULT;
+      *use_same_format = true;
+      break;
+    }
+  }
+}
+
+void UpdateGraphKernelInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
+                                       const std::string &default_format, bool use_same_format,
+                                       std::vector<std::string> *graph_input_format,
+                                       std::vector<TypeId> *graph_input_type) {
+  MS_EXCEPTION_IF_NULL(graph_input_format);
+  MS_EXCEPTION_IF_NULL(graph_input_type);
+  // We set same format to all inputs of graph kernel subgraph, and process this latter.
+  // We set dtype to inputs of graph kernel subgraph same as infer dtypes.
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t i = 0; i < input_num; ++i) {
+    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    if (use_same_format) {
+      bool can_convert = true;
+      if (default_format == kOpFormat_FRAC_NZ) {
+        auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+        if (!CanConvertDefaultShapeToNZ(infer_shape)) {
+          MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead";
+          can_convert = false;
+        }
+      }
+      if (can_convert) {
+        graph_input_format->push_back(default_format);
+      } else {
+        graph_input_format->push_back(kOpFormat_DEFAULT);
+      }
+      graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
+      continue;
+    }
+    if (!input_kernel_node->isa<Parameter>()) {
+      // subgraph parameter from output of other nodes.
+      graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i));
+      graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
+      continue;
+    }
+    auto para = input_kernel_node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(para);
+    if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
+      // parameter already selected.
+      graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0));
+      graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0));
+      continue;
+    }
+    // weight parameter.
+    graph_input_format->push_back(default_format);
+    graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0));
+  }
+
+  for (size_t i = 0; i < input_num; ++i) {
+    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+    std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
+    std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
+    builder.SetOutputsFormat(outputs_format);
+    builder.SetOutputsDeviceType(outputs_device_type);
+    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
+  }
+}
+
+void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
+                       const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
+                       const FuncGraphManagerPtr &mng) {
+  MS_EXCEPTION_IF_NULL(mng);
+  for (size_t i = 0; i < node_list.size(); ++i) {
+    // select nodes in subgraph.
+    auto anf_node = node_list[i];
+    MS_EXCEPTION_IF_NULL(anf_node);
+    auto cnode = anf_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
+    SelectKernelInfo(cnode, KernelType::AKG_KERNEL);
+    // Update ReduceSum
+    if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) {
+      continue;
+    }
+    UpdateFracNZReduceOp(cnode);
+    // If ReduceSum's output is 1d and not Default format, convert it to Default format
+    auto out_format = AnfAlgo::GetOutputFormat(cnode, 0);
+    if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
+      continue;
+    }
+    auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
+    // Insert EquivFormat node, then select kernel info again
+    std::vector<AnfNodePtr> trans_inputs;
+    trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
+    trans_inputs.push_back(cnode);
+    CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
+    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)},
+                                        {AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get());
+    AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue<std::vector<std::string>>({"x"}), trans_node);
+    if (trans_node->kernel_info() == nullptr) {
+      trans_node->set_kernel_info(std::make_shared<device::KernelInfo>());
+    }
+    SelectKernelInfo(trans_node, KernelType::AKG_KERNEL);
+    mng->Replace(cnode, trans_node);
+  }
+}
+
+void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
+                            const std::vector<AnfNodePtr> &input_list, const FuncGraphManagerPtr &mng,
+                            const std::string &default_format, std::vector<std::string> *graph_input_format,
+                            std::vector<TypeId> *graph_input_type) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  MS_EXCEPTION_IF_NULL(mng);
+  MS_EXCEPTION_IF_NULL(graph_input_format);
+  MS_EXCEPTION_IF_NULL(graph_input_type);
+  // update graph input format and dtype use inner ops.
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (graph_input_format->size() != input_num) {
+    MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
+                      << "], [%" << graph_input_format->size() << "] != [%" << input_num << "]";
+  }
+  std::vector<bool> need_update(input_num, false);
+  auto &node_users = mng->node_users();
+  for (size_t i = 0; i < input_num; ++i) {
+    auto &input = input_list[i];
+    auto iter = node_users.find(input);
+    if (iter == node_users.end() || iter->second.empty()) {
+      continue;
+    }
+    for (auto &node_user : iter->second) {
+      if (node_user.first->kernel_info() == nullptr ||
+          node_user.first->kernel_info()->select_kernel_build_info() == nullptr) {
+        // maybe not a real kernel.
+        continue;
+      }
+      auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1));
+      if (user_format != (*graph_input_format)[i]) {
+        MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
+                        << kernel_node->DebugString()
+                        << "] selected different format. we use default: " << default_format;
+        (*graph_input_format)[i] = default_format;
+        need_update[i] = true;
+      }
+
+      if (kernel_node->input(i + 1)->isa<Parameter>()) {
+        auto user_dtype = AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1));
+        if (user_dtype != (*graph_input_type)[i]) {
+          TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
+          MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
+                          << kernel_node->DebugString()
+                          << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
+          (*graph_input_type)[i] = default_dtype;
+          need_update[i] = true;
+        }
+      }
+    }
+  }
+
+  for (size_t i = 0; i < input_num; ++i) {
+    if (!need_update[i]) {
+      continue;
+    }
+    need_update[i] = false;
+    MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
+                  << "] to: " << (*graph_input_format)[i];
+    MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
+                  << "] to: " << TypeIdLabel((*graph_input_type)[i]);
+    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+    std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
+    std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
+    builder.SetOutputsFormat(outputs_format);
+    builder.SetOutputsDeviceType(outputs_device_type);
+    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
+  }
+
+  ResetKernelBuildInfo(kernel_node);
+  // select nodes in subgraph again.
+  for (size_t i = 0; i < node_list.size(); ++i) {
+    auto anf_node = node_list[i];
+    MS_EXCEPTION_IF_NULL(anf_node);
+    auto cnode = anf_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+    size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode);
+    for (size_t j = 0; j < cnode_input_num; ++j) {
+      auto input_node = cnode->input(j + 1);
+      MS_EXCEPTION_IF_NULL(input_node);
+      if (!IsValueNode<tensor::Tensor>(input_node)) {
+        continue;
+      }
+      // reset format and dtype of const tensor.
+      builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
+      builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
+      AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get());
+    }
+    SelectKernelInfo(node_list[i]->cast<CNodePtr>(), KernelType::AKG_KERNEL);
+  }
+}
+
+void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
+                        const std::vector<std::string> &graph_input_format,
+                        const std::vector<TypeId> &graph_input_type) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<std::string> graph_output_format;
+  std::vector<TypeId> graph_output_type;
+  for (size_t i = 0; i < output_index.size(); ++i) {
+    auto const &output = output_index[i];
+    graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
+    TypeId output_type(kTypeUnknown);
+    if (output.first->isa<CNode>()) {
+      output_type = AnfAlgo::GetCNodeOutputPrecision(output.first);
+    }
+    if (output_type == kTypeUnknown) {
+      output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second);
+    }
+    graph_output_type.push_back(output_type);
+  }
+
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
+  graph_info_builder.SetInputsFormat(graph_input_format);
+  graph_info_builder.SetInputsDeviceType(graph_input_type);
+  graph_info_builder.SetOutputsFormat(graph_output_format);
+  graph_info_builder.SetOutputsDeviceType(graph_output_type);
+  graph_info_builder.SetProcessor(kernel::Processor::AICORE);
+  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
+  graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
+  auto graph_selected_info = graph_info_builder.Build();
+  MS_EXCEPTION_IF_NULL(graph_selected_info);
+  AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
+  SetTensorDeviceInfo(*graph_selected_info, kernel_node);
+}
+
+void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  MS_EXCEPTION_IF_NULL(func_graph);
+
+  // collect input info of funcgraph
+  std::vector<AnfNodePtr> node_list;
+  std::vector<AnfNodePtr> input_list;
+  std::vector<AnfNodePtr> output_list;
+  kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
+  if (input_list.size() != kernel_node->inputs().size() - 1) {
+    MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode["
+                                << kernel_node->DebugString() << "], [%" << input_list.size() << "] != ["
+                                << kernel_node->inputs().size() << "]";
+  }
+
+  std::string default_format;
+  bool use_same_format = true;
+  GetDefaultFormat(kernel_node, &default_format, &use_same_format);
+  MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] use same input format[" << default_format
+                << "] for ParameterWeight.";
+
+  std::vector<std::string> graph_input_format;
+  std::vector<TypeId> graph_input_type;
+  UpdateGraphKernelInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format,
+                                    &graph_input_type);
+
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+  }
+  auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
+  UpdateEquivFormat(output_index, node_list, func_graph, mng);
+  node_list.clear();
+  input_list.clear();
+  output_list.clear();
+  kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
+
+  // update graph input format and dtype use inner ops.
+  UpdateFormatsAndDtypes(kernel_node, node_list, input_list, mng, default_format, &graph_input_format,
+                         &graph_input_type);
+
+  // set fix_precision for kernel when the me prim has fix_precision attr
+  UpdateKernelInfo(node_list);
+
+  output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
+  SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type);
+}
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
| @@ -24,7 +24,7 @@ namespace device { | |||||
| namespace ascend { | namespace ascend { | ||||
| void GraphDescReporter::ReportData() { | void GraphDescReporter::ReportData() { | ||||
| for (const auto &node : cnode_list_) { | for (const auto &node : cnode_list_) { | ||||
| if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) { | |||||
| if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { | |||||
| MS_LOG(WARNING) << "Skip non tbe kernel"; | MS_LOG(WARNING) << "Skip non tbe kernel"; | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -31,7 +31,7 @@ void TaskDescReporter::ReportData() { | |||||
| size_t task_index = 0; | size_t task_index = 0; | ||||
| for (const auto &node : cnode_list_) { | for (const auto &node : cnode_list_) { | ||||
| if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) { | |||||
| if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { | |||||
| MS_LOG(WARNING) << "Skip non tbe kernel"; | MS_LOG(WARNING) << "Skip non tbe kernel"; | ||||
| ++task_index; | ++task_index; | ||||
| continue; | continue; | ||||
| @@ -43,7 +43,37 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve | |||||
| void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { | void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { | ||||
| MS_EXCEPTION_IF_NULL(anf_node_ptr); | MS_EXCEPTION_IF_NULL(anf_node_ptr); | ||||
| if (anf_node_ptr->inputs().size() != 2) { | if (anf_node_ptr->inputs().size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "atomic Addr clean Node Input nodes not equal 2."; | |||||
| // akg process | |||||
| // set atomic clean addr | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, anf_node_ptr)) { | |||||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAutomicOutputIndexs); | |||||
| auto graph = anf_node_ptr->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto node_users = manager->node_users(); | |||||
| if (node_users[anf_node_ptr].empty()) { | |||||
| MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty."; | |||||
| } | |||||
| auto depend_node = node_users[anf_node_ptr].pop().first; | |||||
| if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) { | |||||
| MS_LOG(EXCEPTION) << "Checking Depend node failed"; | |||||
| } | |||||
| if (node_users[depend_node].empty()) { | |||||
| MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty."; | |||||
| } | |||||
| auto post_node = node_users[depend_node].pop().first; | |||||
| for (auto index : clean_output_indexs) { | |||||
| auto device_address = AnfAlgo::GetOutputAddr(post_node, index); | |||||
| kernel::AddressPtr input = std::make_shared<kernel::Address>(); | |||||
| input->addr = device_address->ptr_; | |||||
| MS_EXCEPTION_IF_NULL(input->addr); | |||||
| input->size = device_address->size_; | |||||
| kernel_inputs->push_back(input); | |||||
| } | |||||
| MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size(); | |||||
| } | |||||
| return; | |||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); | MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); | ||||
| auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>(); | auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>(); | ||||
| @@ -59,7 +89,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP | |||||
| input->size = device_address->size_; | input->size = device_address->size_; | ||||
| kernel_inputs->push_back(input); | kernel_inputs->push_back(input); | ||||
| } | } | ||||
| MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | |||||
| MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | |||||
| } | } | ||||
| // set clean workspace address | // set clean workspace address | ||||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) { | if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) { | ||||
| @@ -16,7 +16,7 @@ | |||||
| #include "device/gpu/gpu_kernel_build.h" | #include "device/gpu/gpu_kernel_build.h" | ||||
| #include <string> | #include <string> | ||||
| #include "kernel/kernel.h" | #include "kernel/kernel.h" | ||||
| #include "kernel/akg/akgkernelbuild.h" | |||||
| #include "kernel/akg/akg_kernel_build.h" | |||||
| #include "kernel/akg/gpu/akg_gpu_kernel_build.h" | #include "kernel/akg/gpu/akg_gpu_kernel_build.h" | ||||
| #include "kernel/gpu/gpu_kernel_factory.h" | #include "kernel/gpu/gpu_kernel_factory.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| @@ -37,7 +37,7 @@ void GpuBuild(const KernelGraphPtr &kernel_graph) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AUTO_DIFF_KERNEL) { | |||||
| if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) { | |||||
| auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); | auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); | ||||
| if (!gpu_kernel_ptr) { | if (!gpu_kernel_ptr) { | ||||
| MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; | MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; | ||||
| @@ -184,7 +184,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) { | |||||
| if (!result) { | if (!result) { | ||||
| result = SelectAkgKernel(kernel_node, builder->Build()); | result = SelectAkgKernel(kernel_node, builder->Build()); | ||||
| kernel_type = AUTO_DIFF_KERNEL; | |||||
| kernel_type = AKG_KERNEL; | |||||
| } | } | ||||
| if (!result) { | if (!result) { | ||||
| @@ -26,6 +26,8 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive_base.h" | #include "ir/primitive_base.h" | ||||
| #include "operator/ops.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support intermediate representation definition | // namespace to support intermediate representation definition | ||||
| CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph) | CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph) | ||||
| @@ -106,10 +108,14 @@ std::string ValueNode::fullname_with_scope() { | |||||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { | bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| if (cnode != nullptr) { | |||||
| if (cnode == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (value != nullptr) { | |||||
| return cnode->IsApply(value); | return cnode->IsApply(value); | ||||
| } | } | ||||
| return false; | |||||
| const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||||
| return prim != nullptr; | |||||
| } | } | ||||
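With this change a null `value` means "any primitive": `IsPrimitiveCNode(node)` is now true for every CNode whose first input holds a Primitive, while the two-argument form still matches one specific primitive. A minimal sketch of the two call forms, assuming the IR headers used in this file; the wrapper function `Classify` is hypothetical and not part of this patch:

```cpp
// Sketch only: both call forms of the reworked IsPrimitiveCNode.
#include "ir/anf.h"
#include "operator/ops.h"

void Classify(const mindspore::AnfNodePtr &node) {
  using namespace mindspore;
  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
    // CNode whose first input is exactly the Depend primitive.
  } else if (IsPrimitiveCNode(node)) {
    // value defaults to nullptr: any CNode with a Primitive as input(0),
    // e.g. an ordinary op call, but not a call whose input(0) is a FuncGraph.
  }
}
```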
| PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { | PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { | ||||
| @@ -124,6 +124,7 @@ class AnfNode : public Base { | |||||
| const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); } | const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); } | ||||
| KernelInfoDevice *kernel_info() { return kernel_info_.get(); } | KernelInfoDevice *kernel_info() { return kernel_info_.get(); } | ||||
| const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; } | |||||
| void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; } | void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; } | ||||
| AbstractBasePtr abstract() const { return abstract_; } | AbstractBasePtr abstract() const { return abstract_; } | ||||
| @@ -395,9 +396,9 @@ static S GetValue(const ValuePtr &value) { | |||||
| std::string GetCNodeFuncName(CNodePtr cnode); | std::string GetCNodeFuncName(CNodePtr cnode); | ||||
| // used to check whether an AnfNode is a cnode with a kind of Primitive as first input | // used to check whether an AnfNode is a cnode with a kind of Primitive as first input | ||||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value); | |||||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr); | |||||
| // used to check whether an AnfNode is a cnode with a Primitive as first input | |||||
| // used to get PrimitivePtr from a cnode first input | |||||
| PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); | PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); | ||||
| // used to check whether an AnfNode is a valuenode having some Primitive value | // used to check whether an AnfNode is a valuenode having some Primitive value | ||||
| @@ -70,7 +70,7 @@ std::string CNode::fullname_with_scope() { | |||||
| } | } | ||||
| fullname_with_scope_ = name; | fullname_with_scope_ = name; | ||||
| } else { | } else { | ||||
| // cnode input 0 should be primitive ptr | |||||
| // cnode input 0 should be primitive ptr or funcgraph ptr | |||||
| auto value_ptr = input(0)->cast<ValueNodePtr>(); | auto value_ptr = input(0)->cast<ValueNodePtr>(); | ||||
| if (value_ptr == nullptr) { | if (value_ptr == nullptr) { | ||||
| MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; | MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; | ||||
| @@ -84,11 +84,23 @@ std::string CNode::fullname_with_scope() { | |||||
| return fullname_with_scope_; | return fullname_with_scope_; | ||||
| } | } | ||||
| PrimitivePtr prim = GetValue<PrimitivePtr>(input_value); | |||||
| auto prim = input_value->cast<PrimitivePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(scope()); | MS_EXCEPTION_IF_NULL(scope()); | ||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| fullname_with_scope_ = | |||||
| scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>()); | |||||
| fullname_with_scope_ = scope()->name() + "/"; | |||||
| if (prim != nullptr) { | |||||
| fullname_with_scope_ += prim->name(); | |||||
| } else { | |||||
| auto func_graph = input_value->cast<FuncGraphPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||||
| if (fg_flag != nullptr) { | |||||
| auto fg_name = GetValue<std::string>(fg_flag); | |||||
| fullname_with_scope_ += "GraphKernel_" + fg_name; | |||||
| } else { | |||||
| fullname_with_scope_ += func_graph->ToString(); | |||||
| } | |||||
| } | |||||
| fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>()); | |||||
| } | } | ||||
| return fullname_with_scope_; | return fullname_with_scope_; | ||||
| @@ -77,9 +77,9 @@ class Bool : public Number { | |||||
| TypeId generic_type_id() const override { return kNumberTypeBool; } | TypeId generic_type_id() const override { return kNumberTypeBool; } | ||||
| TypePtr DeepCopy() const override { return std::make_shared<Bool>(); } | TypePtr DeepCopy() const override { return std::make_shared<Bool>(); } | ||||
| std::string ToString() const override { return "Bool_"; } | |||||
| std::string ToReprString() const override { return "bool_"; } | |||||
| std::string DumpText() const override { return "Bool_"; } | |||||
| std::string ToString() const override { return "Bool"; } | |||||
| std::string ToReprString() const override { return "bool"; } | |||||
| std::string DumpText() const override { return "Bool"; } | |||||
| }; | }; | ||||
| // Int | // Int | ||||
| @@ -34,7 +34,7 @@ namespace mindspore { | |||||
| * Methods of Graph | * Methods of Graph | ||||
| */ | */ | ||||
| FuncGraph::FuncGraph() | FuncGraph::FuncGraph() | ||||
| : flags_(), | |||||
| : attrs_(), | |||||
| transforms_(), | transforms_(), | ||||
| parameter_default_value_(), | parameter_default_value_(), | ||||
| seen_(0), | seen_(0), | ||||
| @@ -95,13 +95,27 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { | |||||
| return p; | return p; | ||||
| } | } | ||||
| bool FuncGraph::has_flag(const std::string &flag) { | |||||
| if (flags_.count(flag)) { | |||||
| return flags_[flag]; | |||||
| bool FuncGraph::has_flag(const std::string &key) { | |||||
| auto iter = attrs_.find(key); | |||||
| if (iter != attrs_.cend()) { | |||||
| if (iter->second->isa<BoolImm>()) { | |||||
| return GetValue<bool>(iter->second); | |||||
| } | |||||
| MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function."; | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| bool FuncGraph::has_attr(const std::string &key) { | |||||
| auto iter = attrs_.find(key); | |||||
| return !(iter == attrs_.cend()); | |||||
| } | |||||
| ValuePtr FuncGraph::get_attr(const std::string &key) { | |||||
| auto iter = attrs_.find(key); | |||||
| return iter == attrs_.cend() ? nullptr : iter->second; | |||||
| } | |||||
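Since flags now live in the same `attrs_` map as general attributes, a boolean flag is simply a `BoolImm`-valued attribute, while arbitrary `ValuePtr` attributes go through `set_attr`/`get_attr`. A small sketch of how a pass might tag a fused subgraph with the new API, assuming the usual IR headers pulled in by `func_graph.h`; the function name and the `"Fused_Cast_Mul"` string are illustrative only:

```cpp
// Sketch only: flag vs. attr usage on the unified attrs_ map.
#include <string>
#include "ir/func_graph.h"

std::string GetGraphKernelName(const mindspore::FuncGraphPtr &fg) {
  using namespace mindspore;
  fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);  // boolean flags are stored as BoolImm attrs
  fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(std::string("Fused_Cast_Mul")));
  // has_flag() on this key would log the "is not a flag" warning above, because the
  // stored value is a string rather than a BoolImm; has_attr()/get_attr() is the right query.
  if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
    return GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));  // "Fused_Cast_Mul"
  }
  return "";
}
```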
| CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | ||||
| CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>()); | CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>()); | ||||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | ||||
| @@ -74,6 +74,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>; | |||||
| const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | ||||
| const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | ||||
| const char FUNC_GRAPH_FLAG_CORE[] = "core"; | const char FUNC_GRAPH_FLAG_CORE[] = "core"; | ||||
| const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel"; | |||||
| const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; | const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; | ||||
| namespace abstract { | namespace abstract { | ||||
| @@ -195,10 +196,19 @@ class FuncGraph : public FuncGraphBase { | |||||
| void set_is_generate(bool generated) { is_generated_ = generated; } | void set_is_generate(bool generated) { is_generated_ = generated; } | ||||
| bool is_generated() const { return is_generated_; } | bool is_generated() const { return is_generated_; } | ||||
| bool has_flag(const std::string &flag); | |||||
| std::unordered_map<std::string, bool> &flags() { return flags_; } | |||||
| void set_flags(const std::unordered_map<std::string, bool> &flags) { flags_ = flags; } | |||||
| void set_flags(const std::string &key, const bool value) { flags_[key] = value; } | |||||
| std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; } | |||||
| void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||||
| for (auto &attr : attrs) { | |||||
| attrs_[attr.first] = attr.second; | |||||
| } | |||||
| } | |||||
| bool has_flag(const std::string &key); | |||||
| void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); } | |||||
| void erase_flag(const std::string &key) { (void)attrs_.erase(key); } | |||||
| bool has_attr(const std::string &key); | |||||
| ValuePtr get_attr(const std::string &key); | |||||
| void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } | |||||
| std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; } | std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; } | ||||
| void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) { | void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) { | ||||
| @@ -317,7 +327,7 @@ class FuncGraph : public FuncGraphBase { | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; } | std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; } | ||||
| std::unordered_map<std::string, bool> flags_; | |||||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||||
| std::unordered_map<std::string, FuncGraphTransform> transforms_; | std::unordered_map<std::string, FuncGraphTransform> transforms_; | ||||
| // parameter default value | // parameter default value | ||||
| std::map<std::string, AnfNodePtr> parameter_default_value_; | std::map<std::string, AnfNodePtr> parameter_default_value_; | ||||
| @@ -90,6 +90,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||||
| new_node->set_abstract(old_node->abstract()); | new_node->set_abstract(old_node->abstract()); | ||||
| ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); | ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); | ||||
| new_node->set_scope(scope); | new_node->set_scope(scope); | ||||
| new_node->set_kernel_info(old_node->kernel_info_ptr()); | |||||
| repl_node_[old_node] = new_node; | repl_node_[old_node] = new_node; | ||||
| nodes_.emplace_back(old_node, new_node); | nodes_.emplace_back(old_node, new_node); | ||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| @@ -211,7 +212,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons | |||||
| MS_EXCEPTION_IF_NULL(target_func_graph); | MS_EXCEPTION_IF_NULL(target_func_graph); | ||||
| TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); | TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); | ||||
| *target_func_graph = std::make_shared<FuncGraph>(); | *target_func_graph = std::make_shared<FuncGraph>(); | ||||
| (*target_func_graph)->set_flags(func_graph->flags()); | |||||
| (*target_func_graph)->set_attrs(func_graph->attrs()); | |||||
| (*target_func_graph)->set_transforms(func_graph->transforms()); | (*target_func_graph)->set_transforms(func_graph->transforms()); | ||||
| (*target_func_graph)->set_has_vararg(func_graph->has_vararg()); | (*target_func_graph)->set_has_vararg(func_graph->has_vararg()); | ||||
| (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); | (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); | ||||
| @@ -636,9 +637,14 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP | |||||
| if (MsContext::GetInstance()->is_multi_graph_sink()) { | if (MsContext::GetInstance()->is_multi_graph_sink()) { | ||||
| if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | ||||
| new_func_graph->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||||
| new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||||
| } | } | ||||
| } | } | ||||
| if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||||
| } | |||||
| return new_func_graph; | return new_func_graph; | ||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -399,8 +399,8 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { | |||||
| depend_inputs.push_back(*iter); | depend_inputs.push_back(*iter); | ||||
| } | } | ||||
| } | } | ||||
| set_flags(GRAPH_FLAG_HAS_EFFECT, false); | |||||
| set_flags(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); | |||||
| set_flag(GRAPH_FLAG_HAS_EFFECT, false); | |||||
| set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); | |||||
| if (!depend_inputs.empty()) { | if (!depend_inputs.empty()) { | ||||
| SetEffectDepends(depend_inputs); | SetEffectDepends(depend_inputs); | ||||
| } | } | ||||
| @@ -9,6 +9,10 @@ if (ENABLE_D) | |||||
| file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | ||||
| "kernel_query.cc" | "kernel_query.cc" | ||||
| "kernel_fusion.cc" | "kernel_fusion.cc" | ||||
| "akg/ascend/*.cc" | |||||
| "akg/akg_kernel_build.cc" | |||||
| "akg/akg_kernel_attrs_process.cc" | |||||
| "akg/akg_kernel_metadata.cc" | |||||
| "tbe/*.cc" | "tbe/*.cc" | ||||
| "aicpu/*.cc" | "aicpu/*.cc" | ||||
| "rts/*.cc" | "rts/*.cc" | ||||
| @@ -33,7 +37,7 @@ if (ENABLE_GPU) | |||||
| file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | ||||
| "gpu/*.cu" | "gpu/*.cu" | ||||
| "akg/gpu/*.cc" | "akg/gpu/*.cc" | ||||
| "akg/akgkernelbuild.cc" | |||||
| "akg/akg_kernel_build.cc" | |||||
| "akg/akg_kernel_attrs_process.cc" | "akg/akg_kernel_attrs_process.cc" | ||||
| ) | ) | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include <map> | #include <map> | ||||
| #include "device/kernel_runtime.h" | #include "device/kernel_runtime.h" | ||||
| #include "kernel/aicpu/aicpu_kernel_mod.h" | #include "kernel/aicpu/aicpu_kernel_mod.h" | ||||
| #include "kernel/akg/akgkernelbuild.h" | |||||
| #include "kernel/akg/akg_kernel_build.h" | |||||
| #include "proto/tensor.pb.h" | #include "proto/tensor.pb.h" | ||||
| #include "proto/tensor_shape.pb.h" | #include "proto/tensor_shape.pb.h" | ||||
| #include "proto/attr.pb.h" | #include "proto/attr.pb.h" | ||||
| @@ -79,6 +79,10 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) { | |||||
| dst_type = "float32"; | dst_type = "float32"; | ||||
| } else if (output_type == kFloat16->type_id()) { | } else if (output_type == kFloat16->type_id()) { | ||||
| dst_type = "float16"; | dst_type = "float16"; | ||||
| } else if (output_type == kInt32->type_id()) { | |||||
| dst_type = "int32"; | |||||
| } else { | |||||
| MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString(); | |||||
| } | } | ||||
| AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); | AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); | ||||
| } | } | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "kernel/akg/akgkernelbuild.h" | |||||
| #include "kernel/akg/akg_kernel_build.h" | |||||
| #include <Python.h> | #include <Python.h> | ||||
| #include <sys/types.h> | #include <sys/types.h> | ||||
| #include <signal.h> | #include <signal.h> | ||||
| @@ -43,7 +43,9 @@ namespace kernel { | |||||
| constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; | constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; | ||||
| constexpr int32_t ARGS_SIZE = 1; | constexpr int32_t ARGS_SIZE = 1; | ||||
| constexpr auto kCompileWithJsonFunc = "compilewithjson"; | constexpr auto kCompileWithJsonFunc = "compilewithjson"; | ||||
| // json key | // json key | ||||
| constexpr auto kOpDesc = "op_desc"; | |||||
| constexpr auto kInputDesc = "input_desc"; | constexpr auto kInputDesc = "input_desc"; | ||||
| constexpr auto kShape = "shape"; | constexpr auto kShape = "shape"; | ||||
| constexpr auto kDataType = "data_type"; | constexpr auto kDataType = "data_type"; | ||||
| @@ -51,13 +53,24 @@ constexpr auto kOutputDesc = "output_desc"; | |||||
| constexpr auto kName = "name"; | constexpr auto kName = "name"; | ||||
| constexpr auto kTensorName = "tensor_name"; | constexpr auto kTensorName = "tensor_name"; | ||||
| constexpr auto kValue = "value"; | constexpr auto kValue = "value"; | ||||
| constexpr auto KInpputNames = "input_names"; | |||||
| constexpr auto KDynInputSizes = "dyn_input_sizes"; | |||||
| constexpr auto KInputNames = "input_names"; | |||||
| constexpr auto KInput = "input"; | constexpr auto KInput = "input"; | ||||
| constexpr auto KDtype = "dtype"; | constexpr auto KDtype = "dtype"; | ||||
| int AkgKernelBuild::op_cnt_ = 0; | |||||
| std::mutex AkgKernelBuild::op_cnt_mtx_; | |||||
| namespace { | |||||
| template <typename T> | |||||
| std::string Vector2Str(const std::vector<T> &inputs) { | |||||
| if (!inputs.empty()) { | |||||
| std::ostringstream oss; | |||||
| (void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator<T>(oss, ", ")); | |||||
| oss << inputs.back(); | |||||
| return oss.str(); | |||||
| } | |||||
| return ""; | |||||
| } | |||||
| } // namespace | |||||
| std::string PyObjectToStr(PyObject *const PyObj) { | |||||
| std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) { | |||||
| char *pChar = nullptr; | char *pChar = nullptr; | ||||
| std::string str_res; | std::string str_res; | ||||
| if (PyObj == nullptr) { | if (PyObj == nullptr) { | ||||
| @@ -76,6 +89,72 @@ std::string PyObjectToStr(PyObject *const PyObj) { | |||||
| return str_res; | return str_res; | ||||
| } | } | ||||
| std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, | |||||
| const std::pair<size_t, size_t> &position) { | |||||
| if (node_json.count(tag) == 0) { | |||||
| MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "]."; | |||||
| return ""; | |||||
| } | |||||
| auto const &tag_desc = node_json[tag]; | |||||
| nlohmann::json first_index; | |||||
| if (tag == kOutputDesc) { | |||||
| first_index = tag_desc; | |||||
| } else if (!tag_desc.is_array() || tag_desc.size() <= position.first) { | |||||
| MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "]."; | |||||
| return ""; | |||||
| } else { | |||||
| first_index = tag_desc[position.first]; | |||||
| } | |||||
| if (!first_index.is_array() || first_index.size() <= position.second) { | |||||
| MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "]."; | |||||
| return ""; | |||||
| } | |||||
| auto const &second_index = first_index[position.second]; | |||||
| if (second_index.count(kTensorName) == 0) { | |||||
| MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "]."; | |||||
| return ""; | |||||
| } | |||||
| return second_index[kTensorName]; | |||||
| } | |||||
| void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position, | |||||
| nlohmann::json *const node_json) { | |||||
| MS_EXCEPTION_IF_NULL(node_json); | |||||
| if (node_json->count(tag) == 0) { | |||||
| MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "]."; | |||||
| return; | |||||
| } | |||||
| nlohmann::json *tag_desc = &((*node_json)[tag]); | |||||
| nlohmann::json *first_index; | |||||
| if (tag == kOutputDesc) { | |||||
| first_index = tag_desc; | |||||
| } else if (!tag_desc->is_array() || tag_desc->size() <= position.first) { | |||||
| MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "]."; | |||||
| return; | |||||
| } else { | |||||
| first_index = &((*tag_desc)[position.first]); | |||||
| } | |||||
| if (!first_index->is_array() || first_index->size() <= position.second) { | |||||
| MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "]."; | |||||
| return; | |||||
| } | |||||
| nlohmann::json *second_index = &((*first_index)[position.second]); | |||||
| if (second_index->count(kTensorName) == 0) { | |||||
| MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "]."; | |||||
| return; | |||||
| } | |||||
| (*second_index)[kTensorName] = new_name; | |||||
| return; | |||||
| } | |||||
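GetTensorName and SetTensorName rely on the per-node JSON layout produced by GenerateSingleKernelJson: `input_desc` is an array with one entry per input slot, each holding an array of tensor objects, while `output_desc` is a flat array of tensor objects (which is why the output branch ignores `position.first`). A minimal, self-contained sketch of that layout using nlohmann::json; the tensor names here are illustrative:

```cpp
// Sketch only: the JSON shape these helpers index into.
#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  nlohmann::json node_json;
  // input_desc: one array per input slot, each holding the tensors of that slot.
  node_json["input_desc"] = nlohmann::json::array();
  node_json["input_desc"].push_back({{{"tensor_name", "input_0"}}});
  node_json["input_desc"].push_back({{{"tensor_name", "input_1"}}});
  // output_desc: a flat array of output tensors.
  node_json["output_desc"] = {{{"tensor_name", "output_0_0"}}};

  // GetTensorName(node_json, kInputDesc, {1, 0}) resolves input_desc[1][0]["tensor_name"].
  std::cout << node_json["input_desc"][1][0]["tensor_name"] << std::endl;  // "input_1"
  // For kOutputDesc only position.second is used: output_desc[0]["tensor_name"].
  std::cout << node_json["output_desc"][0]["tensor_name"] << std::endl;    // "output_0_0"
  return 0;
}
```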
| int AkgKernelBuild::op_cnt_ = 0; | |||||
| std::mutex AkgKernelBuild::op_cnt_mtx_; | |||||
| std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) { | std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) { | ||||
| MS_EXCEPTION_IF_NULL(anf_node); | MS_EXCEPTION_IF_NULL(anf_node); | ||||
| std::string device; | std::string device; | ||||
| @@ -187,10 +266,7 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j | |||||
| for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { | for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { | ||||
| // dtype : float16 | // dtype : float16 | ||||
| auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index); | auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index); | ||||
| TypePtr type_ptr = TypeIdToType(type_id); | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||||
| std::string dtype = type_ptr->ToString(); | |||||
| dtype = Dtype2String(dtype); | |||||
| std::string dtype = TypeId2String(type_id); | |||||
| if (dtype.empty()) { | if (dtype.empty()) { | ||||
| MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. "; | MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. "; | ||||
| return false; | return false; | ||||
| @@ -198,13 +274,23 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j | |||||
| nlohmann::json input_desc_json; | nlohmann::json input_desc_json; | ||||
| input_desc_json[kDataType] = dtype; | input_desc_json[kDataType] = dtype; | ||||
| input_desc_json[kName] = op_input_name; | input_desc_json[kName] = op_input_name; | ||||
| input_desc_json[kTensorName] = | |||||
| op_input_name + "_" + std::to_string(real_input_index) + "_" + std::to_string(input_i); | |||||
| input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); | |||||
| input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index)); | |||||
| auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); | |||||
| if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { | |||||
| MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2) | |||||
| << "] as const tensor, shape: [" << Vector2Str(input_shape) | |||||
| << "], value: " << input_desc_json[kValue]; | |||||
| input_shape.clear(); | |||||
| } | |||||
| if (input_shape.empty()) { | |||||
| input_shape.push_back(1); | |||||
| } | |||||
| input_desc_json[kShape] = input_shape; | |||||
| input_list.emplace_back(input_desc_json); | input_list.emplace_back(input_desc_json); | ||||
| real_input_index++; | |||||
| } | } | ||||
| inputs_json->emplace_back(input_list); | inputs_json->emplace_back(input_list); | ||||
| real_input_index++; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -220,10 +306,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann:: | |||||
| for (size_t i = 0; i < output_tensor_num; i++) { | for (size_t i = 0; i < output_tensor_num; i++) { | ||||
| nlohmann::json output_json; | nlohmann::json output_json; | ||||
| auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i); | auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i); | ||||
| TypePtr type_ptr = TypeIdToType(type_id); | |||||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||||
| std::string dtype = type_ptr->ToString(); | |||||
| dtype = Dtype2String(dtype); | |||||
| std::string dtype = TypeId2String(type_id); | |||||
| if (dtype.empty()) { | if (dtype.empty()) { | ||||
| MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. "; | MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. "; | ||||
| return false; | return false; | ||||
| @@ -232,7 +315,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann:: | |||||
| std::string output_name = outputs[i]->name(); | std::string output_name = outputs[i]->name(); | ||||
| output_json[kDataType] = dtype; | output_json[kDataType] = dtype; | ||||
| output_json[kName] = output_name; | output_json[kName] = output_name; | ||||
| output_json[kTensorName] = output_name + "_" + std::to_string(i); | |||||
| output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc()); | |||||
| output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i); | output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i); | ||||
| outputs_json->push_back(output_json); | outputs_json->push_back(output_json); | ||||
| } | } | ||||
| @@ -358,15 +441,14 @@ bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const | |||||
| MS_EXCEPTION_IF_NULL(op_info_ptr); | MS_EXCEPTION_IF_NULL(op_info_ptr); | ||||
| // get basic params from currentNodeOpDesc | // get basic params from currentNodeOpDesc | ||||
| (*node_json)["platform"] = "AKG"; | |||||
| (*node_json)[kName] = op_name; | (*node_json)[kName] = op_name; | ||||
| (*node_json)["fusion_type"] = AnfAlgo::GetFusionType(anf_node); | |||||
| (*node_json)["impl_path"] = op_info_ptr->impl_path(); | (*node_json)["impl_path"] = op_info_ptr->impl_path(); | ||||
| (*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node); | (*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node); | ||||
| (*node_json)["composite"] = false; | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| ValuePtr input_names_v = primitive->GetAttr(KInpputNames); | |||||
| ValuePtr input_names_v = primitive->GetAttr(KInputNames); | |||||
| if (input_names_v == nullptr) { | if (input_names_v == nullptr) { | ||||
| MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "]."; | MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "]."; | ||||
| return false; | return false; | ||||
| @@ -465,12 +547,12 @@ KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNod | |||||
| (void)alarm(0); | (void)alarm(0); | ||||
| if (pRes == nullptr) { | if (pRes == nullptr) { | ||||
| MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" | MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" | ||||
| << PyObjectToStr(pArg) << ")."; | |||||
| << AkgKernelBuild::PyObjectToStr(pArg) << ")."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (PyObject_IsTrue(pRes) != 1) { | if (PyObject_IsTrue(pRes) != 1) { | ||||
| MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" | MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" | ||||
| << PyObjectToStr(pArg) << ")."; | |||||
| << AkgKernelBuild::PyObjectToStr(pArg) << ")."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -513,5 +595,29 @@ KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vecto | |||||
| << "]"; | << "]"; | ||||
| return kernel_pack; | return kernel_pack; | ||||
| } | } | ||||
| size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (input_idx + 1 >= cnode->inputs().size()) { | |||||
| MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" | |||||
| << cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]"; | |||||
| } | |||||
| auto input_node = cnode->input(input_idx + 1); | |||||
| if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) { | |||||
| size_t index = input_tensor_idx_.size(); | |||||
| input_tensor_idx_[input_node] = index; | |||||
| } | |||||
| return input_tensor_idx_[input_node]; | |||||
| } | |||||
| size_t AkgKernelBuild::GetOutputTensorIdxInc() { | |||||
| size_t idx = output_tensor_idx_++; | |||||
| return idx; | |||||
| } | |||||
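The two counters make the generated tensor names stable within one kernel: every distinct input node gets `input_<k>` in first-seen order (so a node feeding several input slots keeps a single name), and each output gets `output_<i>_<j>`. A small self-contained sketch of the first-seen indexing, using strings in place of AnfNodePtr:

```cpp
// Sketch only: the first-seen indexing behind GetInputTensorIdxInc,
// demonstrated on strings instead of AnfNodePtr.
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, size_t> input_tensor_idx;
  auto idx_of = [&input_tensor_idx](const std::string &node) {
    if (input_tensor_idx.find(node) == input_tensor_idx.end()) {
      size_t index = input_tensor_idx.size();
      input_tensor_idx[node] = index;
    }
    return input_tensor_idx[node];
  };
  std::cout << "input_" << idx_of("%x") << std::endl;  // input_0
  std::cout << "input_" << idx_of("%y") << std::endl;  // input_1
  std::cout << "input_" << idx_of("%x") << std::endl;  // input_0: the same node reuses its index
  return 0;
}
```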
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -32,29 +32,45 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| class AkgKernelBuild { | class AkgKernelBuild { | ||||
| public: | public: | ||||
| AkgKernelBuild() = default; | |||||
| AkgKernelBuild() { | |||||
| input_tensor_idx_ = {}; | |||||
| output_tensor_idx_ = 0; | |||||
| } | |||||
| ~AkgKernelBuild() = default; | ~AkgKernelBuild() = default; | ||||
| KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size, | KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size, | ||||
| std::vector<size_t> *const output_size); | std::vector<size_t> *const output_size); | ||||
| static std::string GetProcessor(const AnfNodePtr &anf_node); | |||||
| static std::string PyObjectToStr(PyObject *const PyObj); | |||||
| private: | |||||
| protected: | |||||
| bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json); | bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json); | ||||
| bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json); | bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json); | ||||
| bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, | bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, | ||||
| const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json); | const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json); | ||||
| KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node); | |||||
| int GetOpCntInc(); | |||||
| size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx); | |||||
| size_t GetOutputTensorIdxInc(); | |||||
| bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, | bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, | ||||
| nlohmann::json *const node_json); | nlohmann::json *const node_json); | ||||
| KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node); | |||||
| int GetOpCntInc(); | |||||
| std::string GetProcessor(const AnfNodePtr &anf_node); | |||||
| static int op_cnt_; | static int op_cnt_; | ||||
| // lock for variable fusionOpCnt in singleton mode | // lock for variable fusionOpCnt in singleton mode | ||||
| static std::mutex op_cnt_mtx_; | static std::mutex op_cnt_mtx_; | ||||
| std::string json_name_; | std::string json_name_; | ||||
| std::string json_info_; | std::string json_info_; | ||||
| std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_; | |||||
| size_t output_tensor_idx_; | |||||
| }; | }; | ||||
| bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size, | |||||
| std::vector<size_t> *const output_size); | |||||
| void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position, | |||||
| nlohmann::json *const node_json); | |||||
| std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, | |||||
| const std::pair<size_t, size_t> &position); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/akg/akg_kernel_metadata.h" | |||||
| #include <memory> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "kernel/oplib/oplib.h" | |||||
| #include "kernel/common_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| void AkgMetadataInfo(const CNodePtr &kernel_node, | |||||
| std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| for (size_t i = 0; i < support_devices.size(); i++) { | |||||
| auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); | |||||
| if (op_info_ptr == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (!ParseMetadata(kernel_node, op_info_ptr, Processor(i), kernel_info_list)) { | |||||
| MS_LOG(WARNING) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "] failed."; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "]."; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (kernel_info_list->empty()) { | |||||
| MS_LOG(WARNING) << "Akg dose not has metadata of op[" << op_name << "]."; | |||||
| } | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include <memory> | |||||
| #include "kernel/kernel_build_info.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ | |||||
| @@ -0,0 +1,385 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/akg/ascend/akg_ascend_kernel_build.h" | |||||
| #include <algorithm> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_set> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <Python.h> | |||||
| #include "ir/dtype.h" | |||||
| #include "ir/func_graph.h" | |||||
| #include "kernel/kernel.h" | |||||
| #include "kernel/common_utils.h" | |||||
| #include "kernel/tbe/tbe_utils.h" | |||||
| #include "kernel/akg/ascend/akg_ascend_kernel_mod.h" | |||||
| #include "kernel/akg/akg_kernel_attrs_process.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| constexpr int32_t PARALLEL_ARGS_SIZE = 3; | |||||
| constexpr int32_t PROCESS_NUM = 16; | |||||
| constexpr int32_t TIME_OUT = 300; | |||||
| constexpr auto kOpDesc = "op_desc"; | |||||
| constexpr auto kShape = "shape"; | |||||
| constexpr auto kDataType = "data_type"; | |||||
| constexpr auto kInputDesc = "input_desc"; | |||||
| constexpr auto kOutputDesc = "output_desc"; | |||||
| constexpr auto kTensorName = "tensor_name"; | |||||
| constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel"; | |||||
| constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler"; | |||||
| bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||||
| MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; | |||||
| auto it = kAkgKernelAttrsProcessMap.find(op_name); | |||||
| if (it != kAkgKernelAttrsProcessMap.end()) { | |||||
| it->second(anf_node); | |||||
| } | |||||
| MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; | |||||
| nlohmann::json node_json; | |||||
| if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { | |||||
| MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed."; | |||||
| return false; | |||||
| } | |||||
| kernel_json_ = node_json.dump(); | |||||
| if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) { | |||||
| MS_LOG(ERROR) << "Cal mem size failed."; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, | |||||
| const std::vector<AnfNodePtr> &input_list, | |||||
| const std::vector<AnfNodePtr> &output_list) { | |||||
| if (anf_nodes.empty() || input_list.empty()) { | |||||
| MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size() | |||||
| << "]."; | |||||
| return false; | |||||
| } | |||||
| MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list [" | |||||
| << input_list.size() << "]."; | |||||
| std::map<AnfNodePtr, nlohmann::json> node_json_map; | |||||
| for (auto const &anf_node : anf_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||||
| if (!AnfAlgo::IsRealKernel(anf_node)) { | |||||
| MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "]."; | |||||
| return false; | |||||
| } | |||||
| auto it = kAkgKernelAttrsProcessMap.find(op_name); | |||||
| if (it != kAkgKernelAttrsProcessMap.end()) { | |||||
| it->second(anf_node); | |||||
| } | |||||
| nlohmann::json node_json; | |||||
| if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { | |||||
| MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed."; | |||||
| return false; | |||||
| } | |||||
| // No need for composite op. | |||||
| node_json.erase("id"); | |||||
| node_json.erase("op"); | |||||
| node_json.erase("composite"); | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| if (primitive->GetAttr("fusion") != nullptr) { | |||||
| node_json["fusion"] = primitive->GetAttr("fusion")->ToString(); | |||||
| } | |||||
| node_json_map[anf_node] = node_json; | |||||
| } | |||||
| for (auto const &anf_node : anf_nodes) { | |||||
| std::vector<int> dyn_input_sizes; | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { | |||||
| dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes)); | |||||
| } | |||||
| bool is_dynamic_input = !dyn_input_sizes.empty(); | |||||
| size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); | |||||
| size_t real_input_index = 0; | |||||
| for (size_t i = 0; i < input_num; ++i) { | |||||
| size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1; | |||||
| for (size_t j = 0; j < input_tensor_num; ++j) { | |||||
| auto tmp_input = GetKernelInput(anf_node, real_input_index); | |||||
| std::string tensor_name = GetTensorName(node_json_map[anf_node], kInputDesc, std::make_pair(i, j)); | |||||
| if (node_json_map.find(tmp_input.first) != node_json_map.end()) { | |||||
| std::string new_tensor_name = | |||||
| GetTensorName(node_json_map[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second)); | |||||
| SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node])); | |||||
| MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of [" | |||||
| << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output [" | |||||
| << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "]."; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of [" | |||||
| << anf_node->fullname_with_scope() << "] is out input."; | |||||
| } | |||||
| real_input_index++; | |||||
| } | |||||
| } | |||||
| } | |||||
| nlohmann::json fused_node_json; | |||||
| std::vector<nlohmann::json> node_json_desc; | |||||
| std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc), | |||||
| [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; }); | |||||
| fused_node_json[kOpDesc] = node_json_desc; | |||||
| nlohmann::json inputs_json; | |||||
| auto input_index = GetInputIndex(anf_nodes, input_list); | |||||
| for (size_t i = 0; i < input_index.size(); ++i) { | |||||
| auto tmp_input = input_index[i]; | |||||
| auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first); | |||||
| std::string dtype = TypeId2String(type_id); | |||||
| nlohmann::json input_desc_json; | |||||
| input_desc_json[kTensorName] = GetTensorName(node_json_map[tmp_input.first], kInputDesc, tmp_input.second); | |||||
| input_desc_json[kDataType] = dtype; | |||||
| input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first); | |||||
| inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json}); | |||||
| } | |||||
| fused_node_json[kInputDesc] = inputs_json; | |||||
| nlohmann::json outputs_json; | |||||
| auto output_index = GetOutputIndex(anf_nodes, input_list, output_list); | |||||
| for (size_t i = 0; i < output_index.size(); ++i) { | |||||
| auto tmp_output = output_index[i]; | |||||
| bool found = false; | |||||
| nlohmann::json output_desc_json; | |||||
| for (size_t input_i = 0; input_i < input_list.size(); ++input_i) { | |||||
| if (tmp_output.first == input_list[input_i]) { | |||||
| output_desc_json = inputs_json[input_i][0]; | |||||
| found = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!found) { | |||||
| auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second); | |||||
| std::string dtype = TypeId2String(type_id); | |||||
| output_desc_json[kTensorName] = | |||||
| GetTensorName(node_json_map[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second)); | |||||
| output_desc_json[kDataType] = dtype; | |||||
| auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second); | |||||
| if (output_shape.empty()) { | |||||
| output_shape.push_back(1); | |||||
| } | |||||
| output_desc_json[kShape] = output_shape; | |||||
| } | |||||
| outputs_json.emplace_back(output_desc_json); | |||||
| } | |||||
| fused_node_json[kOutputDesc] = outputs_json; | |||||
| size_t hash_id = std::hash<std::string>()(fused_node_json.dump()); | |||||
| json_name_ = "Fused_"; | |||||
| auto fg = anf_nodes[0]->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||||
| if (attr_val != nullptr) { | |||||
| auto fg_attr = GetValue<std::string>(attr_val); | |||||
| (void)json_name_.append(fg_attr).append("_"); | |||||
| } | |||||
| (void)json_name_.append(std::to_string(hash_id)); | |||||
| fused_node_json["composite_graph"] = fg->ToString(); | |||||
| fused_node_json["op"] = json_name_; | |||||
| fused_node_json["platform"] = "AKG"; | |||||
| fused_node_json["process"] = "aicore"; | |||||
| fused_node_json["composite"] = true; | |||||
| kernel_json_ = fused_node_json.dump(); | |||||
| if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) { | |||||
| MS_LOG(ERROR) << "Cal mem size failed."; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void GenParallelCompileFuncArgs(const std::vector<std::string> &kernel_jsons, PyObject **p_args) { | |||||
| MS_EXCEPTION_IF_NULL(p_args); | |||||
| *p_args = PyTuple_New(PARALLEL_ARGS_SIZE); | |||||
| PyObject *arg1 = PyList_New(kernel_jsons.size()); | |||||
| for (int i = 0; i < PyList_Size(arg1); ++i) { | |||||
| PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str())); | |||||
| } | |||||
| PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM); | |||||
| PyObject *arg3 = Py_BuildValue("i", TIME_OUT); | |||||
| (void)PyTuple_SetItem(*p_args, 0, arg1); | |||||
| (void)PyTuple_SetItem(*p_args, 1, arg2); | |||||
| (void)PyTuple_SetItem(*p_args, 2, arg3); | |||||
| } | |||||
| bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) { | |||||
| // Remove cached nodes, gather unique nodes, and collect repeated nodes that need post-processing. | |||||
| std::vector<std::string> jsons; | |||||
| std::unordered_set<std::string> json_name_set; | |||||
| std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> repeat_nodes; | |||||
| for (const auto &[builder, anf_node] : build_args) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| auto json_name = builder.json_name(); | |||||
| MS_LOG(DEBUG) << "Akg start compile op: " << json_name; | |||||
| auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); | |||||
| if (cached_kernel_pack != nullptr) { | |||||
| MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope[" | |||||
| << anf_node->fullname_with_scope() << "]."; | |||||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack); | |||||
| kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); | |||||
| kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); | |||||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||||
| continue; | |||||
| } | |||||
| if (json_name_set.count(json_name) != 0) { | |||||
| repeat_nodes.push_back({builder, anf_node}); | |||||
| continue; | |||||
| } | |||||
| json_name_set.insert(json_name); | |||||
| auto node_json = builder.kernel_json(); | |||||
| kernel::SaveJsonInfo(json_name, node_json); | |||||
| jsons.push_back(node_json); | |||||
| } | |||||
| // No nodes need to be compiled! | |||||
| if (jsons.empty()) { | |||||
| return true; | |||||
| } | |||||
| // Try to call the Python method to compile the nodes in parallel. | |||||
| PyObject *p_module = nullptr; | |||||
| PyObject *p_func = nullptr; | |||||
| PyObject *p_arg = nullptr; | |||||
| PyObject *p_res = nullptr; | |||||
| p_module = PyImport_ImportModule(kMultiProcModule); | |||||
| if (p_module == nullptr) { | |||||
| MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "]."; | |||||
| return false; | |||||
| } | |||||
| p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc); | |||||
| GenParallelCompileFuncArgs(jsons, &p_arg); | |||||
| MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size() | |||||
| << " Akg kernels parallelly."; | |||||
| p_res = PyEval_CallObject(p_func, p_arg); | |||||
| if (p_res == nullptr) { | |||||
| PyErr_Print(); | |||||
| MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" | |||||
| << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; | |||||
| return false; | |||||
| } | |||||
| if (PyObject_IsTrue(p_res) != 1) { | |||||
| PyErr_Print(); | |||||
| MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" | |||||
| << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; | |||||
| return false; | |||||
| } | |||||
| // All unique kernels are compiled now; insert them into the cache and set their kernel mods. | |||||
| for (const auto &[builder, anf_node] : build_args) { | |||||
| auto json_name = builder.json_name(); | |||||
| auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); | |||||
| if (new_kernel_pack == nullptr) { | |||||
| MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope[" | |||||
| << anf_node->fullname_with_scope() << "]."; | |||||
| return false; | |||||
| } | |||||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack); | |||||
| kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); | |||||
| kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); | |||||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||||
| MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!"; | |||||
| } | |||||
| // Handle repeated nodes. | |||||
| for (const auto &[builder, anf_node] : repeat_nodes) { | |||||
| auto node_json = builder.kernel_json(); | |||||
| auto json_name = builder.json_name(); | |||||
| auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); | |||||
| if (cached_kernel_pack == nullptr) return false; | |||||
| MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope[" | |||||
| << anf_node->fullname_with_scope() << "]."; | |||||
| auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack); | |||||
| kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); | |||||
| kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); | |||||
| AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); | |||||
| } | |||||
| return true; | |||||
| } | |||||
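| // Collect the kernel json for every node (fused json for graph kernels, single-op json otherwise) and hand the batch to AkgOpParallelBuild. | |||||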
| bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) { | |||||
| std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> json_and_node; | |||||
| for (const auto &anf_node : anf_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| AkgAscendKernelBuilder akg_cce_kernel_builder; | |||||
| KernelPackPtr kernel_pack = nullptr; | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (AnfAlgo::IsGraphKernel(cnode)) { | |||||
| auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto mng = func_graph->manager(); | |||||
| if (mng == nullptr) { | |||||
| mng = Manage(func_graph, true); | |||||
| func_graph->set_manager(mng); | |||||
| } | |||||
| std::vector<AnfNodePtr> node_list; | |||||
| std::vector<AnfNodePtr> input_list; | |||||
| std::vector<AnfNodePtr> output_list; | |||||
| std::string op_name = AnfAlgo::GetCNodeName(anf_node); | |||||
| MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]"; | |||||
| GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); | |||||
| if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) { | |||||
| MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "]."; | |||||
| } | |||||
| } else { | |||||
| if (!akg_cce_kernel_builder.CollectJson(anf_node)) { | |||||
| MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "]."; | |||||
| } | |||||
| } | |||||
| json_and_node.push_back({akg_cce_kernel_builder, anf_node}); | |||||
| } | |||||
| if (json_and_node.empty()) { | |||||
| MS_LOG(DEBUG) << "There is no kernel needed to be compiled."; | |||||
| return true; | |||||
| } | |||||
| return AkgOpParallelBuild(json_and_node); | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "ir/anf.h" | |||||
| #include "kernel/kernel.h" | |||||
| #include "kernel/akg/akg_kernel_build.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class AkgAscendKernelBuilder : public AkgKernelBuild { | |||||
| public: | |||||
| AkgAscendKernelBuilder() = default; | |||||
| ~AkgAscendKernelBuilder() = default; | |||||
| bool CollectJson(const AnfNodePtr &anf_node); | |||||
| bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list, | |||||
| const std::vector<AnfNodePtr> &output_list); | |||||
| std::string json_name() const { return json_name_; } | |||||
| std::string kernel_json() const { return kernel_json_; } | |||||
| const std::vector<size_t> &input_size_list() const { return input_size_list_; } | |||||
| const std::vector<size_t> &output_size_list() const { return output_size_list_; } | |||||
| private: | |||||
| std::string kernel_json_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| }; | |||||
| bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ | |||||
| @@ -0,0 +1,181 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/akg/ascend/akg_ascend_kernel_mod.h" | |||||
| #include <algorithm> | |||||
| #include <fstream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "nlohmann/json.hpp" | |||||
| #include "runtime/rt.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "utils/convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| using std::fstream; | |||||
| using std::map; | |||||
| using std::mutex; | |||||
| using std::string; | |||||
| using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>; | |||||
| using tbe::KernelManager; | |||||
| constexpr uint32_t DEFAULT_BLOCK_DIM = 1; | |||||
| /** | |||||
| * @brief The info table contains the func stub, block dim and kernel file buffer. | |||||
| */ | |||||
| AkgKernelMod::AkgKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} | |||||
| void AkgKernelMod::SetInputSizeList(const std::vector<size_t> &size_list) { input_size_list_ = size_list; } | |||||
| void AkgKernelMod::SetOutputSizeList(const std::vector<size_t> &size_list) { output_size_list_ = size_list; } | |||||
| void AkgKernelMod::SetWorkspaceSizeList(const std::vector<size_t> &size_list) { workspace_size_list_ = size_list; } | |||||
| const std::vector<size_t> &AkgKernelMod::GetInputSizeList() const { return input_size_list_; } | |||||
| const std::vector<size_t> &AkgKernelMod::GetOutputSizeList() const { return output_size_list_; } | |||||
| const std::vector<size_t> &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||||
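| // Debug helper: when MS_KERNEL_DUMP_DATA is set, copy each input/output buffer from device to host and write it to a file. | |||||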
| void DumpData(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||||
| const char *dump_data = getenv("MS_KERNEL_DUMP_DATA"); | |||||
| if (dump_data) { | |||||
| int idx = 0; | |||||
| for (const auto &x : inputs) { | |||||
| std::vector<char> buf(x->size); | |||||
| if (RT_ERROR_NONE != rtMemcpy(buf.data(), buf.size(), reinterpret_cast<const void *>(x->addr), x->size, | |||||
| RT_MEMCPY_DEVICE_TO_HOST)) { | |||||
| MS_LOG(WARNING) << "Call runtime rtMemcpy error."; | |||||
| return; | |||||
| } | |||||
| std::string file_name("input_"); | |||||
| file_name += std::to_string(idx); | |||||
| std::ofstream file(file_name, std::ios::binary); | |||||
| if (file.is_open()) { | |||||
| (void)file.write(buf.data(), SizeToLong(buf.size())); | |||||
| file.close(); | |||||
| idx++; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Open file failed."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| idx = 0; | |||||
| for (const auto &x : outputs) { | |||||
| std::vector<char> buf(x->size); | |||||
| if (RT_ERROR_NONE != rtMemcpy(buf.data(), buf.size(), reinterpret_cast<const void *>(x->addr), x->size, | |||||
| RT_MEMCPY_DEVICE_TO_HOST)) { | |||||
| MS_LOG(WARNING) << "Call runtime rtMemcpy error."; | |||||
| return; | |||||
| } | |||||
| std::string file_name("output_"); | |||||
| file_name += std::to_string(idx); | |||||
| std::ofstream file(file_name, std::ios::binary); | |||||
| if (file.is_open()) { | |||||
| (void)file.write(buf.data(), SizeToLong(buf.size())); | |||||
| file.close(); | |||||
| idx++; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Open file failed."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
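| // Resolve the function stub and block dim from the kernel pack, pack the input/output addresses, and launch the kernel on the given stream. | |||||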
| bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||||
| if (stream_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "stream_ptr should not be nullptr."; | |||||
| return false; | |||||
| } | |||||
| if (kernel_pack_ == nullptr) { | |||||
| MS_LOG(ERROR) << "kernel pack should not be nullptr."; | |||||
| return false; | |||||
| } | |||||
| uint32_t block_dim = DEFAULT_BLOCK_DIM; // The default block dim is 1. | |||||
| auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); | |||||
| if (func_stub == 0) { | |||||
| MS_LOG(ERROR) << "GenFuncStub failed."; | |||||
| return false; | |||||
| } | |||||
| // pack all addresses into a vector. | |||||
| std::vector<void *> runtime_args; | |||||
| (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtime_args), | |||||
| [](const AddressPtr &input) -> void * { return input->addr; }); | |||||
| (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args), | |||||
| [](const AddressPtr &output) -> void * { return output->addr; }); | |||||
| rtL2Ctrl_t *l2ctrl = nullptr; | |||||
| auto stream = reinterpret_cast<rtStream_t *>(stream_ptr); | |||||
| if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast<void *>(func_stub), block_dim, runtime_args.data(), | |||||
| SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) { | |||||
| MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; | |||||
| return false; | |||||
| } | |||||
| DumpData(inputs, outputs); | |||||
| return true; | |||||
| } | |||||
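| // Build a TbeTaskInfo (for the GE model runner) from the stub function name, block dim and packed input/output addresses. | |||||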
| std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, uint32_t stream_id) { | |||||
| if (kernel_pack_ == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "kernel pack should not be nullptr."; | |||||
| } | |||||
| std::vector<uint8_t> args; | |||||
| uint32_t args_size = 0; | |||||
| std::vector<uint8_t> sm_desc; | |||||
| void *binary = nullptr; | |||||
| uint32_t binary_size = 0; | |||||
| std::vector<uint8_t> meta_data; | |||||
| std::vector<void *> input_data_addrs; | |||||
| std::vector<void *> output_data_addrs; | |||||
| std::vector<void *> workspace_addrs; | |||||
| // pack all addresses into a vector. | |||||
| (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), | |||||
| [](const AddressPtr &input) -> void * { return input->addr; }); | |||||
| (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), | |||||
| [](const AddressPtr &output) -> void * { return output->addr; }); | |||||
| uint32_t block_dim = DEFAULT_BLOCK_DIM; // The default block dim is 1. | |||||
| auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); | |||||
| if (func_stub == 0) { | |||||
| MS_LOG(EXCEPTION) << "GenFuncStub failed."; | |||||
| } | |||||
| std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); | |||||
| MS_LOG(DEBUG) << "The block_dim is:" << block_dim; | |||||
| TbeTaskInfoPtr task_info_ptr = std::make_shared<ge::model_runner::TbeTaskInfo>( | |||||
| stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, input_data_addrs, | |||||
| output_data_addrs, workspace_addrs); | |||||
| return {task_info_ptr}; | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "kernel/ascend_kernel_mod.h" | |||||
| #include "kernel/tbe/tbe_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class AkgKernelMod : public AscendKernelMod { | |||||
| public: | |||||
| explicit AkgKernelMod(const KernelPackPtr &kernel_pack); | |||||
| ~AkgKernelMod() final {} | |||||
| void SetInputSizeList(const std::vector<size_t> &size_list); | |||||
| void SetOutputSizeList(const std::vector<size_t> &size_list); | |||||
| void SetWorkspaceSizeList(const std::vector<size_t> &size_list); | |||||
| const std::vector<size_t> &GetInputSizeList() const override; | |||||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override; | |||||
| std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, uint32_t stream_id) override; | |||||
| private: | |||||
| KernelPackPtr kernel_pack_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| using AkgKernelModPtr = std::shared_ptr<AkgKernelMod>; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ | |||||
| @@ -18,7 +18,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "kernel/kernel.h" | #include "kernel/kernel.h" | ||||
| #include "kernel/akg/akgkernelbuild.h" | |||||
| #include "kernel/akg/akg_kernel_build.h" | |||||
| #include "kernel/akg/gpu/akg_gpu_kernel_mod.h" | #include "kernel/akg/gpu/akg_gpu_kernel_mod.h" | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| @@ -23,6 +23,11 @@ | |||||
| #include "nlohmann/json.hpp" | #include "nlohmann/json.hpp" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| #include "ir/manager.h" | |||||
| #include "ir/meta_tensor.h" | |||||
| #include "ir/func_graph.h" | |||||
| #include "operator/ops.h" | |||||
| #include "utils/graph_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| @@ -48,12 +53,6 @@ const std::map<TypeId, std::string> type_id_str_map = { | |||||
| {TypeId::kNumberTypeBool, "bool"}, | {TypeId::kNumberTypeBool, "bool"}, | ||||
| }; | }; | ||||
| const std::map<std::string, std::string> DATATYPE_STRING_MAP{ | |||||
| {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, | |||||
| {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, | |||||
| {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "bool"}, {"Float64", "double"}, | |||||
| }; | |||||
| const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = { | const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = { | ||||
| {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"}, | {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"}, | ||||
| {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"}, | {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"}, | ||||
| @@ -243,14 +242,6 @@ TypeId DtypeToTypeId(const std::string &dtypes) { | |||||
| } | } | ||||
| } | } | ||||
| std::string Dtype2String(const std::string &dtypes) { | |||||
| auto iter = DATATYPE_STRING_MAP.find(dtypes); | |||||
| if (iter == DATATYPE_STRING_MAP.end()) { | |||||
| MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes; | |||||
| } | |||||
| return iter->second; | |||||
| } | |||||
| std::string TypeId2String(TypeId type_id) { | std::string TypeId2String(TypeId type_id) { | ||||
| auto iter = type_id_str_map.find(type_id); | auto iter = type_id_str_map.find(type_id); | ||||
| if (iter == type_id_str_map.end()) { | if (iter == type_id_str_map.end()) { | ||||
| @@ -361,7 +352,7 @@ bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou | |||||
| output_num = 1; | output_num = 1; | ||||
| } else { | } else { | ||||
| if (output_idx < real_output_num) { | if (output_idx < real_output_num) { | ||||
| MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; | |||||
| MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; | |||||
| output_num = 1; | output_num = 1; | ||||
| } | } | ||||
| } | } | ||||
| @@ -403,7 +394,7 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu | |||||
| } | } | ||||
| if (imply_type == kAKG) { | if (imply_type == kAKG) { | ||||
| builder->SetKernelType(AUTO_DIFF_KERNEL); | |||||
| builder->SetKernelType(AKG_KERNEL); | |||||
| } else if (imply_type == kAICPU) { | } else if (imply_type == kAICPU) { | ||||
| builder->SetKernelType(AICPU_KERNEL); | builder->SetKernelType(AICPU_KERNEL); | ||||
| } else { | } else { | ||||
| @@ -634,5 +625,256 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie | |||||
| } | } | ||||
| unique_grad->indices_size_ = unique_indices_size + 1; | unique_grad->indices_size_ = unique_indices_size + 1; | ||||
| } | } | ||||
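| // Return the real kernel node and output index that feed the given input of anf_node. | |||||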
| std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| if (index >= AnfAlgo::GetInputTensorNum(anf_node)) { | |||||
| MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs."; | |||||
| } | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| return AnfAlgo::VisitKernel(anf_node, 0); | |||||
| } else { | |||||
| return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0); | |||||
| } | |||||
| } | |||||
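| // For each graph input, find the consuming node in node_list and record its input position, unfolding dynamic inputs via the kAttrDynInputSizes attribute. | |||||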
| std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list, | |||||
| const std::vector<AnfNodePtr> &input_list) { | |||||
| std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index; | |||||
| for (size_t i = 0; i < input_list.size(); ++i) { | |||||
| auto const &input = input_list[i]; | |||||
| MS_EXCEPTION_IF_NULL(input); | |||||
| bool found = false; | |||||
| // using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>; | |||||
| auto mng = input->func_graph()->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| const NodeUsersMap &users = mng->node_users(); | |||||
| auto input_users = users.find(input); | |||||
| if (input_users == users.end() || input_users->second.empty()) { | |||||
| MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" | |||||
| << input->func_graph()->ToString() << "] has no users."; | |||||
| } | |||||
| for (auto const &input_user : input_users->second) { | |||||
| for (auto const &anf_node : node_list) { | |||||
| if (anf_node != input_user.first) { | |||||
| continue; | |||||
| } | |||||
| std::vector<int> dyn_input_sizes; | |||||
| auto prim = AnfAlgo::GetCNodePrimitive(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (prim->GetAttr(kAttrDynInputSizes) != nullptr) { | |||||
| dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes)); | |||||
| } | |||||
| if (dyn_input_sizes.empty()) { | |||||
| input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0))); | |||||
| found = true; | |||||
| break; | |||||
| } else { | |||||
| int used_as_idx = input_user.second - 1; | |||||
| int accum_idx = 0; | |||||
| size_t dyn_i = 0; | |||||
| for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) { | |||||
| accum_idx += dyn_input_sizes[dyn_i]; | |||||
| if (used_as_idx < accum_idx) { | |||||
| input_index.push_back(std::make_pair( | |||||
| anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i]))))); | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (dyn_i != dyn_input_sizes.size()) { | |||||
| found = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (found) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!found) { | |||||
| MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" | |||||
| << input->func_graph()->ToString() << "] found no related kernel info."; | |||||
| } | |||||
| } | |||||
| return input_index; | |||||
| } | |||||
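| // Map every output in output_list to the producing kernel node (or forwarded input) and its output index. | |||||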
| std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list, | |||||
| const std::vector<AnfNodePtr> &input_list, | |||||
| const std::vector<AnfNodePtr> &output_list) { | |||||
| std::vector<std::pair<AnfNodePtr, size_t>> output_index; | |||||
| for (size_t i = 0; i < output_list.size(); ++i) { | |||||
| auto const &output = output_list[i]; | |||||
| MS_EXCEPTION_IF_NULL(output); | |||||
| bool found = false; | |||||
| auto pree_node = AnfAlgo::VisitKernel(output, 0); | |||||
| auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first); | |||||
| if (pos != std::end(node_list)) { | |||||
| output_index.push_back(pree_node); | |||||
| continue; | |||||
| } | |||||
| auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first); | |||||
| if (ret != std::end(input_list)) { | |||||
| output_index.push_back(std::make_pair(pree_node.first, 0)); | |||||
| found = true; | |||||
| } | |||||
| if (!found) { | |||||
| MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of [" | |||||
| << output->func_graph()->ToString() << "] found no related kernel info."; | |||||
| } | |||||
| } | |||||
| return output_index; | |||||
| } | |||||
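| // Collect the real kernel CNodes of func_graph in topological order. | |||||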
| void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) { | |||||
| MS_EXCEPTION_IF_NULL(node_list); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return()); | |||||
| for (auto const &node : node_lists) { | |||||
| if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) { | |||||
| node_list->push_back(node); | |||||
| } | |||||
| } | |||||
| } | |||||
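| // Overload that also collects the graph parameters as input_list and the real outputs (flattening a MakeTuple output) as output_list. | |||||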
| void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list, | |||||
| std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) { | |||||
| MS_EXCEPTION_IF_NULL(node_list); | |||||
| MS_EXCEPTION_IF_NULL(input_list); | |||||
| MS_EXCEPTION_IF_NULL(output_list); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| GetValidKernelNodes(func_graph, node_list); | |||||
| auto parameters = func_graph->parameters(); | |||||
| input_list->insert(input_list->begin(), parameters.begin(), parameters.end()); | |||||
| auto func_output = func_graph->output(); | |||||
| MS_EXCEPTION_IF_NULL(func_output); | |||||
| if (func_output->isa<CNode>()) { | |||||
| // multi output. | |||||
| auto cnode = func_output->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto input0 = cnode->input(kAnfPrimitiveIndex); | |||||
| MS_EXCEPTION_IF_NULL(input0); | |||||
| if (IsPrimitive(input0, prim::kPrimMakeTuple)) { | |||||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) { | |||||
| auto input_node = cnode->input(input_idx); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first); | |||||
| } | |||||
| } else { | |||||
| // single output. | |||||
| output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); | |||||
| } | |||||
| } else { | |||||
| // single output. | |||||
| output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); | |||||
| } | |||||
| } | |||||
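| // If the given input is a constant tensor, write its (first) value into node_json; float32, float16 and int32 are supported. | |||||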
| bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(node_json); | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (input_idx + 1 >= cnode->size()) { | |||||
| MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" | |||||
| << cnode->inputs().size() << "][" << cnode->DebugString() << "]"; | |||||
| } | |||||
| auto input_node = cnode->input(input_idx + 1); | |||||
| if (!IsValueNode<tensor::Tensor>(input_node)) { | |||||
| return false; | |||||
| } | |||||
| auto tensor = GetValueNode<tensor::TensorPtr>(input_node); | |||||
| if (tensor == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto type_id = tensor->data_type(); | |||||
| auto *data = tensor->data_c(); | |||||
| MS_EXCEPTION_IF_NULL(data); | |||||
| if (tensor->DataDim() > 1 || tensor->DataSize() != 1) { | |||||
| // Not a scalar tensor; only its first element is used. | |||||
| MS_LOG(WARNING) << "Only the first value is taken from tensor whose data size != 1, [" << input_node->DebugString(2) << "]"; | |||||
| } | |||||
| if (type_id == kFloat32->type_id()) { | |||||
| float *val = static_cast<float *>(data); | |||||
| MS_EXCEPTION_IF_NULL(val); | |||||
| (*node_json)["value"] = val[0]; | |||||
| MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "]."; | |||||
| return true; | |||||
| } else if (type_id == kFloat16->type_id()) { | |||||
| float16 *val = static_cast<float16 *>(data); | |||||
| MS_EXCEPTION_IF_NULL(val); | |||||
| (*node_json)["value"] = static_cast<float>(val[0]); | |||||
| MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "]."; | |||||
| return true; | |||||
| } else if (type_id == kInt32->type_id()) { | |||||
| int *val = static_cast<int *>(data); | |||||
| MS_EXCEPTION_IF_NULL(val); | |||||
| (*node_json)["value"] = val[0]; | |||||
| MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "]."; | |||||
| return true; | |||||
| } | |||||
| MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]"; | |||||
| return false; | |||||
| } | |||||
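| // Collect the real output nodes of func_graph together with their output indices, expanding a MakeTuple output into its elements. | |||||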
| void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(node_list); | |||||
| auto output = func_graph->output(); | |||||
| MS_EXCEPTION_IF_NULL(output); | |||||
| if (AnfAlgo::IsRealKernel(output)) { | |||||
| // single output. | |||||
| node_list->push_back(std::make_pair(output, 0)); | |||||
| return; | |||||
| } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { | |||||
| auto output_cnode = output->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(output_cnode); | |||||
| // multi output. | |||||
| auto &inputs = output_cnode->inputs(); | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||||
| auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0); | |||||
| node_list->push_back(in_with_idx); | |||||
| } | |||||
| return; | |||||
| } | |||||
| MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2) | |||||
| << " of graph: " << func_graph->ToString(); | |||||
| } | |||||
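| // A node is a weight boundary if it is a ValueNode or a weight Parameter. | |||||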
| bool IsWeightBoundary(const AnfNodePtr &node) { | |||||
| if (node->isa<ValueNode>()) { | |||||
| return true; | |||||
| } | |||||
| if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,9 +20,12 @@ | |||||
| #include <dirent.h> | #include <dirent.h> | ||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <unordered_set> | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <utility> | |||||
| #include <nlohmann/json.hpp> | |||||
| #include "kernel/kernel.h" | #include "kernel/kernel.h" | ||||
| #include "kernel/oplib/opinfo.h" | #include "kernel/oplib/opinfo.h" | ||||
| #include "kernel/kernel_build_info.h" | #include "kernel/kernel_build_info.h" | ||||
| @@ -79,13 +82,11 @@ bool CheckCache(const std::string &kernel_name); | |||||
| KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); | KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); | ||||
| KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); | KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); | ||||
| TypeId DtypeToTypeId(const std::string &dtypes); | TypeId DtypeToTypeId(const std::string &dtypes); | ||||
| std::string Dtype2String(const std::string &dtypes); | |||||
| std::string Dtype2ShortType(const std::string &dtypes); | std::string Dtype2ShortType(const std::string &dtypes); | ||||
| std::string TypeId2String(TypeId type_id); | std::string TypeId2String(TypeId type_id); | ||||
| size_t GetDtypeNbyte(const std::string &dtypes); | size_t GetDtypeNbyte(const std::string &dtypes); | ||||
| bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor, | bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor, | ||||
| std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list); | std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list); | ||||
| bool IsAtomicNode(const CNodePtr &kernel_node); | |||||
| void SaveJsonInfo(const std::string &json_name, const std::string &info); | void SaveJsonInfo(const std::string &json_name, const std::string &info); | ||||
| std::string GetProcessor(const AnfNodePtr &anf_node); | std::string GetProcessor(const AnfNodePtr &anf_node); | ||||
| bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b); | bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b); | ||||
| @@ -94,6 +95,18 @@ void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGr | |||||
| size_t outer_dim); | size_t outer_dim); | ||||
| void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, | void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, | ||||
| size_t outer_dim); | size_t outer_dim); | ||||
| std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index); | |||||
| std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list, | |||||
| const std::vector<AnfNodePtr> &input_list); | |||||
| std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list, | |||||
| const std::vector<AnfNodePtr> &input_list, | |||||
| const std::vector<AnfNodePtr> &output_list); | |||||
| void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list, | |||||
| std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list); | |||||
| void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list); | |||||
| bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json); | |||||
| void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list); | |||||
| bool IsWeightBoundary(const AnfNodePtr &node); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,7 +17,7 @@ | |||||
| #include <fstream> | #include <fstream> | ||||
| #include "mindspore/ccsrc/kernel/kernel.h" | #include "mindspore/ccsrc/kernel/kernel.h" | ||||
| #include "kernel/kernel.h" | #include "kernel/kernel.h" | ||||
| #include "kernel/akg/akgkernelbuild.h" | |||||
| #include "kernel/akg/akg_kernel_build.h" | |||||
| #include "nlohmann/json.hpp" | #include "nlohmann/json.hpp" | ||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| #include "pipeline/parse/python_adapter.h" | #include "pipeline/parse/python_adapter.h" | ||||
| @@ -27,7 +27,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AUTO_DIFF_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; | |||||
| enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; | |||||
| namespace kernel { | namespace kernel { | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "kernel/rts/rt_kernel_info.h" | #include "kernel/rts/rt_kernel_info.h" | ||||
| #include "kernel/hccl/hccl_kernel_metadata.h" | #include "kernel/hccl/hccl_kernel_metadata.h" | ||||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" | #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" | ||||
| #include "kernel/akg/akg_kernel_metadata.h" | |||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -59,10 +60,14 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, | |||||
| } | } | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||||
| void KernelQueryAll(const CNodePtr &kernel_node, | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | MS_EXCEPTION_IF_NULL(kernel_info_list); | ||||
| TbeMetadataInfo(kernel_node, kernel_info_list); | TbeMetadataInfo(kernel_node, kernel_info_list); | ||||
| if (kernel_info_list->empty()) { | if (kernel_info_list->empty()) { | ||||
| AicpuMetadataInfo(kernel_node, kernel_info_list); | AicpuMetadataInfo(kernel_node, kernel_info_list); | ||||
| if (!kernel_info_list->empty()) { | if (!kernel_info_list->empty()) { | ||||
| @@ -82,6 +87,28 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||||
| if (kernel_info_list->empty()) { | if (kernel_info_list->empty()) { | ||||
| MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!"; | MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!"; | ||||
| } | } | ||||
| } | |||||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list, | |||||
| KernelType kernel_type) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| switch (kernel_type) { | |||||
| case KernelType::AKG_KERNEL: | |||||
| AkgMetadataInfo(kernel_node, kernel_info_list); | |||||
| break; | |||||
| default: | |||||
| KernelQueryAll(kernel_node, kernel_info_list); | |||||
| break; | |||||
| } | |||||
| if (kernel_info_list->empty()) { | |||||
| MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!"; | |||||
| } | |||||
| // check output | |||||
| FilterInvalidKernelInfo(kernel_node, kernel_info_list); | FilterInvalidKernelInfo(kernel_node, kernel_info_list); | ||||
| } | } | ||||
| @@ -25,7 +25,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | |||||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list, | |||||
| KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); | |||||
| void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); | ||||
| bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | ||||
| bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); | ||||
| @@ -272,8 +272,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im | |||||
| auto context = MsContext::GetInstance(); | auto context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context); | MS_EXCEPTION_IF_NULL(context); | ||||
| bool is_gpu = (context->device_target() == kGPUDevice); | bool is_gpu = (context->device_target() == kGPUDevice); | ||||
| if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) || | |||||
| (!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) { | |||||
| if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) { | |||||
| MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) | MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) | ||||
| << ", current op num: " << op_info_.size(); | << ", current op num: " << op_info_.size(); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -347,7 +347,7 @@ static int TypeStrToDstType(const std::string &type_str) { | |||||
| ret = 4; | ret = 4; | ||||
| } else if (type_str == "UInt64") { | } else if (type_str == "UInt64") { | ||||
| ret = 10; | ret = 10; | ||||
| } else if (type_str == "Bool_") { | |||||
| } else if (type_str == "Bool") { | |||||
| ret = 12; | ret = 12; | ||||
| } else { | } else { | ||||
| MS_LOG(INFO) << "Error type str is invailed: " << type_str; | MS_LOG(INFO) << "Error type str is invailed: " << type_str; | ||||
| @@ -51,7 +51,7 @@ const std::map<TypeId, std::string> type_id_str_maps = { | |||||
| const std::map<std::string, std::string> type_str_maps = { | const std::map<std::string, std::string> type_str_maps = { | ||||
| {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, | {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, | ||||
| {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, | {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, | ||||
| {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "int8"}, {"Float64", "float64"}, | |||||
| {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool", "int8"}, {"Float64", "float64"}, | |||||
| }; | }; | ||||
| const std::unordered_map<std::string, size_t> type_nbyte_maps = { | const std::unordered_map<std::string, size_t> type_nbyte_maps = { | ||||
| @@ -334,8 +334,8 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL | |||||
| FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { | FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { | ||||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | ||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptrGraph->debug_info()->set_name("hyper_map"); | ptrGraph->debug_info()->set_name("hyper_map"); | ||||
| AnfNodePtr ptrFnArg = nullptr; | AnfNodePtr ptrFnArg = nullptr; | ||||
| @@ -389,7 +389,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tu | |||||
| MS_EXCEPTION_IF_NULL(a_tuple); | MS_EXCEPTION_IF_NULL(a_tuple); | ||||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | FuncGraphPtr ret = std::make_shared<FuncGraph>(); | ||||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret->debug_info()->set_name("tail"); | ret->debug_info()->set_name("tail"); | ||||
| AnfNodePtr ptrTup = ret->add_parameter(); | AnfNodePtr ptrTup = ret->add_parameter(); | ||||
| @@ -409,7 +409,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list | |||||
| MS_EXCEPTION_IF_NULL(a_list); | MS_EXCEPTION_IF_NULL(a_list); | ||||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | FuncGraphPtr ret = std::make_shared<FuncGraph>(); | ||||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret->debug_info()->set_name("tail"); | ret->debug_info()->set_name("tail"); | ||||
| AnfNodePtr ptrList = ret->add_parameter(); | AnfNodePtr ptrList = ret->add_parameter(); | ||||
| @@ -481,10 +481,10 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg | |||||
| grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); | grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); | ||||
| } | } | ||||
| b->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| b->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| b->set_output(b->NewCNode(grads)); | b->set_output(b->NewCNode(grads)); | ||||
| fg->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); | fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); | ||||
| (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); | (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); | ||||
| return fg; | return fg; | ||||
| @@ -504,7 +504,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, | |||||
| const std::vector<AnfNodePtr> ¶ms_list, const std::vector<AnfNodePtr> &args, | const std::vector<AnfNodePtr> ¶ms_list, const std::vector<AnfNodePtr> &args, | ||||
| bool applyJ) { | bool applyJ) { | ||||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | FuncGraphPtr ret = std::make_shared<FuncGraph>(); | ||||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| auto weights_node = weights; | auto weights_node = weights; | ||||
| if (weights == nullptr && !args.empty()) { | if (weights == nullptr && !args.empty()) { | ||||
| @@ -625,7 +625,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp | |||||
| std::ostringstream ss; | std::ostringstream ss; | ||||
| ss << "grad{" << nparam << "}"; | ss << "grad{" << nparam << "}"; | ||||
| dfBuilder->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| dfBuilder->debug_info()->set_name(ss.str()); | dfBuilder->debug_info()->set_name(ss.str()); | ||||
| ParameterPtr param_graph = dfBuilder->add_parameter(); | ParameterPtr param_graph = dfBuilder->add_parameter(); | ||||
| @@ -671,7 +671,7 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_lis | |||||
| } | } | ||||
| FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>(); | FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>(); | ||||
| fg_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| fg_ptr->debug_info()->set_name("list_map"); | fg_ptr->debug_info()->set_name("list_map"); | ||||
| AnfNodePtr fn = fg_ptr->add_parameter(); | AnfNodePtr fn = fg_ptr->add_parameter(); | ||||
| @@ -741,7 +741,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr | |||||
| // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) | // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) | ||||
| FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>(); | FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>(); | ||||
| fgtrue_ptr->debug_info()->set_name("ftrue"); | fgtrue_ptr->debug_info()->set_name("ftrue"); | ||||
| fgtrue_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); | CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); | ||||
| auto inputs = fgtrue_output_cnode->inputs(); | auto inputs = fgtrue_output_cnode->inputs(); | ||||
| @@ -751,7 +751,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr | |||||
| FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>(); | FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>(); | ||||
| fgfalse_ptr->debug_info()->set_name("ffalse"); | fgfalse_ptr->debug_info()->set_name("ffalse"); | ||||
| fgfalse_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| fgfalse_ptr->set_output(resl); | fgfalse_ptr->set_output(resl); | ||||
| AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), | AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), | ||||
| @@ -808,7 +808,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li | |||||
| } | } | ||||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | FuncGraphPtr ret = std::make_shared<FuncGraph>(); | ||||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| AnfNodePtr p_tup_a = ret->add_parameter(); | AnfNodePtr p_tup_a = ret->add_parameter(); | ||||
| AnfNodePtr p_tup_b = ret->add_parameter(); | AnfNodePtr p_tup_b = ret->add_parameter(); | ||||
| @@ -912,7 +912,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ | |||||
| GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value); | GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value); | ||||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | FuncGraphPtr ret = std::make_shared<FuncGraph>(); | ||||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| AnfNodePtr p_tuple = ret->add_parameter(); | AnfNodePtr p_tuple = ret->add_parameter(); | ||||
| (void)ret->add_parameter(); | (void)ret->add_parameter(); | ||||
| @@ -941,7 +941,7 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar | |||||
| AbstractBasePtrList branches = branches_abs->elements(); | AbstractBasePtrList branches = branches_abs->elements(); | ||||
| if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) { | if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) { | ||||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | ||||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| AnfNodePtr functions = ret_graph->add_parameter(); | AnfNodePtr functions = ret_graph->add_parameter(); | ||||
| auto index = ret_graph->add_parameter(); | auto index = ret_graph->add_parameter(); | ||||
| @@ -304,7 +304,7 @@ FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrLi | |||||
| } | } | ||||
| auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters()); | auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters()); | ||||
| func_graph->set_output(new_cnode); | func_graph->set_output(new_cnode); | ||||
| func_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| } // namespace prim | } // namespace prim | ||||
| @@ -35,7 +35,7 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList & | |||||
| MS_EXCEPTION_IF_NULL(arg0_list); | MS_EXCEPTION_IF_NULL(arg0_list); | ||||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | FuncGraphPtr ret = std::make_shared<FuncGraph>(); | ||||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret->debug_info()->set_name("append"); | ret->debug_info()->set_name("append"); | ||||
| AnfNodePtr arg0_node = ret->add_parameter(); | AnfNodePtr arg0_node = ret->add_parameter(); | ||||
| @@ -51,8 +51,8 @@ AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &f | |||||
| FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { | FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { | ||||
| // Generate func for leaf nodes | // Generate func for leaf nodes | ||||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | ||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptrGraph->debug_info()->set_name("map"); | ptrGraph->debug_info()->set_name("map"); | ||||
| AnfNodePtr ptrFnArg = nullptr; | AnfNodePtr ptrFnArg = nullptr; | ||||
| if (fn_leaf_ == nullptr) { | if (fn_leaf_ == nullptr) { | ||||
| @@ -237,8 +237,8 @@ AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, c | |||||
| FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { | FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { | ||||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | ||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptrGraph->debug_info()->set_name("map"); | ptrGraph->debug_info()->set_name("map"); | ||||
| AnfNodePtr ptrFnArg = nullptr; | AnfNodePtr ptrFnArg = nullptr; | ||||
| @@ -51,7 +51,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ | |||||
| (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); | (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); | ||||
| auto ret_graph = std::make_shared<FuncGraph>(); | auto ret_graph = std::make_shared<FuncGraph>(); | ||||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| AnfNodePtr fnNode = ret_graph->add_parameter(); | AnfNodePtr fnNode = ret_graph->add_parameter(); | ||||
| std::vector<AnfNodePtr> elems; | std::vector<AnfNodePtr> elems; | ||||
| @@ -57,7 +57,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe | |||||
| return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size()); | return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size()); | ||||
| }); | }); | ||||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | ||||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| for (size_t idx = 0; idx < args_spec_list.size(); idx++) { | for (size_t idx = 0; idx < args_spec_list.size(); idx++) { | ||||
| (void)ret_graph->add_parameter(); | (void)ret_graph->add_parameter(); | ||||
| } | } | ||||
| @@ -50,6 +50,12 @@ const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not"); | |||||
| const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and"); | const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and"); | ||||
| const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or"); | const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or"); | ||||
| const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq"); | const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq"); | ||||
| const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater"); | |||||
| const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual"); | |||||
| const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||||
| const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||||
| const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||||
| const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual"); | |||||
| // Type introspection | // Type introspection | ||||
| const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | ||||
| @@ -166,17 +172,20 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | |||||
| const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | ||||
| const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | ||||
| const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | ||||
| const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||||
| const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||||
| const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||||
| const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); | const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); | ||||
| const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd"); | const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd"); | ||||
| const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar"); | const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar"); | ||||
| const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | ||||
| const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | ||||
| const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | |||||
| const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | |||||
| const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | |||||
| const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | |||||
| const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); | |||||
| // NN | // NN | ||||
| const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | ||||
| const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax"); | |||||
| const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | ||||
| const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | ||||
| const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | ||||
| @@ -253,6 +262,7 @@ const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||||
| const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | ||||
| const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | ||||
| const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | ||||
| const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | |||||
| // Comm ops | // Comm ops | ||||
| const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | ||||
| @@ -59,6 +59,12 @@ extern const PrimitivePtr kPrimBoolNot; | |||||
| extern const PrimitivePtr kPrimBoolAnd; | extern const PrimitivePtr kPrimBoolAnd; | ||||
| extern const PrimitivePtr kPrimBoolOr; | extern const PrimitivePtr kPrimBoolOr; | ||||
| extern const PrimitivePtr kPrimBoolEq; | extern const PrimitivePtr kPrimBoolEq; | ||||
| extern const PrimitivePtr kPrimGreater; | |||||
| extern const PrimitivePtr kPrimGreaterEqual; | |||||
| extern const PrimitivePtr kPrimLess; | |||||
| extern const PrimitivePtr kPrimLessEqual; | |||||
| extern const PrimitivePtr kPrimEqual; | |||||
| extern const PrimitivePtr kPrimNotEqual; | |||||
| // Type introspection | // Type introspection | ||||
| extern const PrimitivePtr kPrimTypeOf; | extern const PrimitivePtr kPrimTypeOf; | ||||
| @@ -157,6 +163,10 @@ extern const PrimitivePtr KPrimTransData; | |||||
| extern const PrimitivePtr kPrimNMSWithMask; | extern const PrimitivePtr kPrimNMSWithMask; | ||||
| extern const PrimitivePtr kPrimPad; | extern const PrimitivePtr kPrimPad; | ||||
| extern const PrimitivePtr kPrimArgMaxWithValue; | extern const PrimitivePtr kPrimArgMaxWithValue; | ||||
| extern const PrimitivePtr kPrimRealDiv; | |||||
| extern const PrimitivePtr kPrimSqrt; | |||||
| extern const PrimitivePtr kPrimReciprocal; | |||||
| extern const PrimitivePtr kPrimExpandDims; | |||||
| // Maths | // Maths | ||||
| extern const PrimitivePtr kPrimTensorAdd; | extern const PrimitivePtr kPrimTensorAdd; | ||||
| @@ -183,9 +193,11 @@ extern const PrimitivePtr kPrimCumProd; | |||||
| extern const PrimitivePtr kPrimSubscalar; | extern const PrimitivePtr kPrimSubscalar; | ||||
| extern const PrimitivePtr kPrimInplaceAdd; | extern const PrimitivePtr kPrimInplaceAdd; | ||||
| extern const PrimitivePtr kPrimInplaceSub; | extern const PrimitivePtr kPrimInplaceSub; | ||||
| extern const PrimitivePtr kPrimPow; | |||||
| // NN | // NN | ||||
| extern const PrimitivePtr kPrimFlatten; | extern const PrimitivePtr kPrimFlatten; | ||||
| extern const PrimitivePtr kPrimSoftmax; | |||||
| extern const PrimitivePtr kPrimLogSoftmax; | extern const PrimitivePtr kPrimLogSoftmax; | ||||
| extern const PrimitivePtr kPrimLogSoftmaxGrad; | extern const PrimitivePtr kPrimLogSoftmaxGrad; | ||||
| extern const PrimitivePtr kPrimApplyCenteredRMSProp; | extern const PrimitivePtr kPrimApplyCenteredRMSProp; | ||||
| @@ -263,6 +275,7 @@ extern const PrimitivePtr kPrimInDict; | |||||
| extern const PrimitivePtr kPrimNotInDict; | extern const PrimitivePtr kPrimNotInDict; | ||||
| extern const PrimitivePtr kPrimMixedPrecisionCast; | extern const PrimitivePtr kPrimMixedPrecisionCast; | ||||
| extern const PrimitivePtr kPrimIsConsant; | extern const PrimitivePtr kPrimIsConsant; | ||||
| extern const PrimitivePtr kPrimEquivFormat; | |||||
| // Comm ops | // Comm ops | ||||
| extern const PrimitivePtr kPrimAllReduce; | extern const PrimitivePtr kPrimAllReduce; | ||||
| @@ -45,10 +45,19 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas | |||||
| : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { | : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { | ||||
| TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info())); | ||||
| k_graph_ = std::make_shared<FuncGraph>(); | k_graph_ = std::make_shared<FuncGraph>(); | ||||
| if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||||
| k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); | |||||
| } | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info())); | ||||
| tape_ = std::make_shared<FuncGraph>(); | tape_ = std::make_shared<FuncGraph>(); | ||||
| // Add "_Grad" postfix | |||||
| if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad"; | |||||
| tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); | |||||
| } | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| dout_ = tape_->add_parameter(); | dout_ = tape_->add_parameter(); | ||||
| @@ -368,7 +377,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { | |||||
| (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg))); | (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg))); | ||||
| (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal))); | (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal))); | ||||
| // Reset defer_inline to enable successive inlining | // Reset defer_inline to enable successive inlining | ||||
| primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); | |||||
| primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); | |||||
| auto functor = std::make_shared<DFunctor>(primal, resources_); | auto functor = std::make_shared<DFunctor>(primal, resources_); | ||||
| functor->Init(); | functor->Init(); | ||||
| @@ -37,7 +37,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt | |||||
| auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { | auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { | ||||
| if (MsContext::GetInstance()->is_multi_graph_sink()) { | if (MsContext::GetInstance()->is_multi_graph_sink()) { | ||||
| if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | ||||
| f->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||||
| f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||||
| } | } | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -78,7 +78,10 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(cons); | MS_EXCEPTION_IF_NULL(cons); | ||||
| auto dt = data->abstract(); | auto dt = data->abstract(); | ||||
| MS_EXCEPTION_IF_NULL(dt); | |||||
| if (dt == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!dt->isa<AbstractClass>()) { | if (!dt->isa<AbstractClass>()) { | ||||
| MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << "."; | MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << "."; | ||||
| } | } | ||||
| @@ -0,0 +1,157 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "optimizer/graph_kernel_reuse.h" | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include <string> | |||||
| #include "./common.h" | |||||
| #include "utils/graph_utils.h" | |||||
| namespace mindspore { | |||||
| /* namespace to support opt */ | |||||
| namespace opt { | |||||
| bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) { | |||||
| if (a->abstract() && b->abstract()) { | |||||
| auto a_type = a->abstract()->GetTypeTrack(); | |||||
| auto b_type = b->abstract()->GetTypeTrack(); | |||||
| if (a_type != b_type) { | |||||
| return false; | |||||
| } | |||||
| auto a_shape = a->abstract()->GetShapeTrack(); | |||||
| auto b_shape = b->abstract()->GetShapeTrack(); | |||||
| if (a_shape != nullptr && a_shape == b_shape) { | |||||
| return true; | |||||
| } | |||||
| if (a_shape != nullptr && b_shape != nullptr && a_shape->isa<abstract::Shape>() && | |||||
| b_shape->isa<abstract::Shape>()) { | |||||
| return a_shape->cast<abstract::ShapePtr>()->shape() == b_shape->cast<abstract::ShapePtr>()->shape(); | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) { | |||||
| bool changed = false; | |||||
| auto fgs = manager->func_graphs(); | |||||
| for (FuncGraphPtr &fg : fgs) { | |||||
| if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| continue; | |||||
| } | |||||
| std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||||
| if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) { | |||||
| if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) { | |||||
| FuncGraphPtr new_fg = nullptr; | |||||
| for (auto &cfg : graph_kernel_ops[key]) { | |||||
| // Skip candidate graphs whose topological node count differs | |||||
| auto fg_topos = TopoSort(fg->get_return()); | |||||
| auto cfg_topos = TopoSort(cfg->get_return()); | |||||
| if (fg_topos.size() != cfg_topos.size()) { | |||||
| continue; | |||||
| } | |||||
| // Compare const tensor | |||||
| bool has_same = true; | |||||
| for (size_t i = 0; i < fg_topos.size(); ++i) { | |||||
| if (IsValueNode<tensor::Tensor>(fg_topos[i])) { | |||||
| if (!IsValueNode<tensor::Tensor>(cfg_topos[i])) { | |||||
| has_same = false; | |||||
| break; | |||||
| } | |||||
| auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]); | |||||
| auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]); | |||||
| if (!tensor1->ValueEqual(*tensor2)) { | |||||
| has_same = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!has_same) { | |||||
| continue; | |||||
| } | |||||
| auto fg_input = fg->parameters(); | |||||
| auto cfg_input = cfg->parameters(); | |||||
| if (fg_input.size() != cfg_input.size()) { | |||||
| continue; | |||||
| } | |||||
| // Compare input | |||||
| for (size_t i = 0; i < fg_input.size(); ++i) { | |||||
| if (!CompareNode(fg_input[i], cfg_input[i])) { | |||||
| has_same = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!has_same) { | |||||
| continue; | |||||
| } | |||||
| // Compare output | |||||
| if (!CompareNode(fg->output(), cfg->output())) { | |||||
| continue; | |||||
| } | |||||
| // Find reusable fg | |||||
| new_fg = cfg; | |||||
| break; | |||||
| } | |||||
| if (new_fg != nullptr) { | |||||
| // Replace current fg with existing fg | |||||
| auto users = fg->func_graph_cnodes_index(); | |||||
| for (auto &iter : users) { | |||||
| auto cnode = iter.first->first->cast<CNodePtr>(); | |||||
| auto new_input = cnode->inputs(); | |||||
| auto main_graph = cnode->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(main_graph); | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) { | |||||
| new_input[1] = NewValueNode(new_fg); | |||||
| } else { | |||||
| new_input[0] = NewValueNode(new_fg); | |||||
| } | |||||
| auto new_cnode = main_graph->NewCNode(new_input); | |||||
| manager->Replace(iter.first->first, new_cnode); | |||||
| changed = true; | |||||
| } | |||||
| } else { | |||||
| // Add current fg to map | |||||
| graph_kernel_ops[key].push_back(fg); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| graph_kernel_ops[key] = {fg}; | |||||
| } | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| manager->AddFuncGraph(root); | |||||
| return DoReplace(manager); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
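A minimal sketch of the situation DoReplace targets, built only from calls that already appear in this patch (std::make_shared<FuncGraph>, set_attr, MakeValue); the "Square" key, the root graph and the manager are illustrative assumptions, not part of the change.

// Two graph-kernel sub graphs registered under the same FUNC_GRAPH_ATTR_GRAPH_KERNEL key.
FuncGraphPtr g1 = std::make_shared<FuncGraph>();
g1->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(std::string("Square")));
FuncGraphPtr g2 = std::make_shared<FuncGraph>();
g2->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(std::string("Square")));
// When g1 and g2 also match in node count, constant tensors, parameters and output,
// GraphKernelReuse().ReuseGraphKernel(root, manager) rewrites every CNode that calls g2
// into an equivalent call of g1, so only one copy of the kernel graph is kept and built.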
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H | |||||
| #define MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H | |||||
| #include <mindspore/ccsrc/session/anf_runtime_algorithm.h> | |||||
| #include <unordered_map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "optimizer/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| // Reuse identical graph kernel sub graphs so that duplicate kernels are built only once. | |||||
| class GraphKernelReuse { | |||||
| public: | |||||
| GraphKernelReuse() : count(0) {} | |||||
| virtual ~GraphKernelReuse() = default; | |||||
| bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { | |||||
| bool chg = ReuseGraphKernel(root, optimizer->resource()->manager()); | |||||
| return chg; | |||||
| } | |||||
| bool CompareNode(const AnfNodePtr a, const AnfNodePtr other); | |||||
| bool DoReplace(const FuncGraphManagerPtr manager); | |||||
| bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager); | |||||
| private: | |||||
| std::unordered_map<std::string, std::vector<FuncGraphPtr>> graph_kernel_ops; | |||||
| int count; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H | |||||
| @@ -41,6 +41,8 @@ | |||||
| #include "optimizer/irpass/incorporate_call.h" | #include "optimizer/irpass/incorporate_call.h" | ||||
| #include "optimizer/irpass/grad_var_prepare.h" | #include "optimizer/irpass/grad_var_prepare.h" | ||||
| #include "optimizer/irpass/param_replace.h" | #include "optimizer/irpass/param_replace.h" | ||||
| #include "optimizer/irpass/mark_interface_fusion.h" | |||||
| #include "optimizer/opt.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -48,7 +50,7 @@ namespace irpass { | |||||
| OptimizeIRPassLib::OptimizeIRPassLib() { | OptimizeIRPassLib::OptimizeIRPassLib() { | ||||
| arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | ||||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | ||||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); | |||||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); | |||||
| special_op_eliminate_ = | special_op_eliminate_ = | ||||
| MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | ||||
| {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, | {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, | ||||
| @@ -90,7 +92,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| replace_refkey_by_param_ = | replace_refkey_by_param_ = | ||||
| MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); | MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); | ||||
| replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); | replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); | ||||
| // Gradient transforms | // Gradient transforms | ||||
| expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); | expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); | ||||
| minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); | minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); | ||||
| @@ -115,6 +116,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| // Incorporation | // Incorporation | ||||
| incorporate_getitem_set_ = | incorporate_getitem_set_ = | ||||
| MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | ||||
| incorporate_getitem_from_param_ = | |||||
| MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); | |||||
| incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); | incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); | ||||
| incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); | incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); | ||||
| @@ -124,6 +127,17 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| // Convert | // Convert | ||||
| print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); | print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); | ||||
| // Unused parameter eliminate | |||||
| unused_parameter_eliminate_ = | |||||
| MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); | |||||
| unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); | |||||
| // AddN eliminate | |||||
| addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); | |||||
| // Mark interface fusion | |||||
| mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); | |||||
| } | } | ||||
| ResolveIRPassLib::ResolveIRPassLib() { | ResolveIRPassLib::ResolveIRPassLib() { | ||||
| @@ -84,6 +84,7 @@ class OptimizeIRPassLib { | |||||
| // Incorporation | // Incorporation | ||||
| SubstitutionPtr incorporate_getitem_set_; | SubstitutionPtr incorporate_getitem_set_; | ||||
| SubstitutionPtr incorporate_getitem_from_param_; | |||||
| SubstitutionPtr incorporate_call_; | SubstitutionPtr incorporate_call_; | ||||
| SubstitutionPtr incorporate_call_switch_; | SubstitutionPtr incorporate_call_switch_; | ||||
| @@ -92,6 +93,16 @@ class OptimizeIRPassLib { | |||||
| // Convert | // Convert | ||||
| SubstitutionPtr print_tuple_wrapper_; | SubstitutionPtr print_tuple_wrapper_; | ||||
| // Unused parameter eliminate | |||||
| SubstitutionPtr unused_parameter_eliminate_; | |||||
| SubstitutionPtr unused_output_eliminate_; | |||||
| // AddN eliminate | |||||
| SubstitutionPtr addn_eliminate_; | |||||
| // Fusion | |||||
| SubstitutionPtr mark_interface_fusion_; | |||||
| }; | }; | ||||
| // the collection of irpass for resolve action | // the collection of irpass for resolve action | ||||
| @@ -145,6 +156,23 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) { | |||||
| return IsValueNode<FuncGraph>(inp0); | return IsValueNode<FuncGraph>(inp0); | ||||
| } | } | ||||
| // Check if CNode Input 0 is Func Graph of graph kernel. | |||||
| inline bool IsCNodeGraphKernel(const AnfNodePtr &node) { | |||||
| if (node == nullptr || !node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto inp0 = node->cast<CNodePtr>()->input(0); | |||||
| if (IsValueNode<FuncGraph>(inp0)) { | |||||
| auto fg = GetValueNode<FuncGraphPtr>(inp0); | |||||
| if (fg == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| // Check if CNode Input 0 is CNode | // Check if CNode Input 0 is CNode | ||||
| inline bool IsCNodeDup(const AnfNodePtr &node) { | inline bool IsCNodeDup(const AnfNodePtr &node) { | ||||
| if (node == nullptr || !node->isa<CNode>()) { | if (node == nullptr || !node->isa<CNode>()) { | ||||
| @@ -83,6 +83,216 @@ class MultiplyByZeroOrOne : public AnfVisitor { | |||||
| AnfNodePtr x_{nullptr}; | AnfNodePtr x_{nullptr}; | ||||
| }; | }; | ||||
| // Helper class that checks whether every element of a Tensor equals `check_value_` | |||||
| // Supported data types: double, float/float32, int/int32 | |||||
| class CheckTensorConstant { | |||||
| public: | |||||
| explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} | |||||
| ~CheckTensorConstant() = default; | |||||
| bool IsTensorConstant(const ValuePtr &value) { | |||||
| if (!value->isa<tensor::Tensor>()) { | |||||
| return false; | |||||
| } | |||||
| auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||||
| TypeId tensor_type = tensor_ptr->Dtype()->type_id(); | |||||
| if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { | |||||
| float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c()); | |||||
| for (int i = 0; i < tensor_ptr->DataSize(); i++) { | |||||
| if (fabs(data2[i] - check_value_) > FLT_EPSILON) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } else if (tensor_type == TypeId::kNumberTypeFloat64) { | |||||
| double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c()); | |||||
| for (int i = 0; i < tensor_ptr->DataSize(); i++) { | |||||
| if (fabs(data2[i] - check_value_) > DBL_EPSILON) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { | |||||
| int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c()); | |||||
| for (int i = 0; i < tensor_ptr->DataSize(); i++) { | |||||
| if (data2[i] != check_value_) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| // Unsupported data types | |||||
| return false; | |||||
| } | |||||
| bool IsTensorScalarConstant(const ValuePtr &value) { | |||||
| if (!value->isa<tensor::Tensor>()) { | |||||
| return false; | |||||
| } | |||||
| auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||||
| if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { | |||||
| return false; | |||||
| } | |||||
| return IsTensorConstant(value); | |||||
| } | |||||
| private: | |||||
| int check_value_; | |||||
| }; | |||||
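A small usage sketch for CheckTensorConstant, assuming the same tensor constructor and helpers (data_c, ElementsNum, IntToSize) used elsewhere in this file are in scope; the zero-filled 0-d float32 tensor is illustrative only.

auto t = std::make_shared<tensor::Tensor>(TypeId::kNumberTypeFloat32, std::vector<int>{});
std::memset(t->data_c(true), 0, sizeof(float) * IntToSize(t->ElementsNum()));
bool all_zero = CheckTensorConstant(0).IsTensorScalarConstant(t);  // expected: true
bool all_one = CheckTensorConstant(1).IsTensorConstant(t);         // expected: false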
| // {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} | |||||
| // {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} | |||||
| class TensorMultiplyByZeroOrOne : public AnfVisitor { | |||||
| public: | |||||
| TensorMultiplyByZeroOrOne() : zero_(MakeValue(0)) {} | |||||
| ~TensorMultiplyByZeroOrOne() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| Reset(); | |||||
| AnfVisitor::Match(prim::kPrimMul)(node); | |||||
| if (is_zero_) { | |||||
| if (x_->func_graph() != node->func_graph()) { | |||||
| return nullptr; | |||||
| } | |||||
| return NewTensorFilledWithData(node); | |||||
| } | |||||
| if (is_one_) { | |||||
| return NewTensorFilledWithData(node, x_); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void Visit(const AnfNodePtr &node) override { | |||||
| if (is_zero_ || is_one_) { | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| if (IsParam(node)) { | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| if (IsCNode(node)) { | |||||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||||
| if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| auto value = node->cast<ValueNodePtr>()->value(); | |||||
| if (CheckTensorConstant(0).IsTensorConstant(value)) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } else if (CheckTensorConstant(1).IsTensorConstant(value)) { | |||||
| is_one_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = node; | |||||
| } | |||||
| void Visit(const ValueNodePtr &vnode) override { | |||||
| auto value = vnode->value(); | |||||
| if (CheckTensorConstant(0).IsTensorConstant(value)) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } else if (CheckTensorConstant(1).IsTensorConstant(value)) { | |||||
| is_one_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = vnode; | |||||
| } | |||||
| void Reset() { | |||||
| x_ = nullptr; | |||||
| is_one_ = false; | |||||
| is_zero_ = false; | |||||
| } | |||||
| void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) { | |||||
| if (!node->isa<ValueNode>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto value = node->cast<ValueNodePtr>()->value(); | |||||
| if (!value->isa<tensor::Tensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||||
| return tensor_ptr->data_c(writable); | |||||
| } | |||||
| // Make a new tensor (when possible) with the same shape as `node` | |||||
| // If x is nullptr then fill the new tensor with zeros | |||||
| // If x is a tensor with an empty shape then fill the new tensor with the single value of x | |||||
| // If x is a tensor with same shape as `node` then return x as result | |||||
| AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) { | |||||
| if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); | |||||
| std::vector<int> tensor_shape = tensor_abstract->shape()->shape(); | |||||
| auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); | |||||
| size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true)); | |||||
| if (x == nullptr) { | |||||
| std::memset(data, 0, mem_size); | |||||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
| new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
| return new_vnode; | |||||
| } | |||||
| // x is not nullptr | |||||
| if (x->isa<CNode>()) { | |||||
| if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| std::vector<int> x_shape = x_abstract->shape()->shape(); | |||||
| if (x_shape != tensor_shape) { | |||||
| return nullptr; | |||||
| } | |||||
| return x; | |||||
| } | |||||
| if (!x->isa<ValueNode>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x_value = x->cast<ValueNodePtr>()->value(); | |||||
| if (!x_value->isa<tensor::Tensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value); | |||||
| if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { | |||||
| return nullptr; | |||||
| } | |||||
| char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x)); | |||||
| if (x_tensor_ptr->DataSize() == 1) { | |||||
| // Broadcast the single element of x into every element of the new tensor. | |||||
| for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { | |||||
| memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); | |||||
| } | |||||
| } else { | |||||
| memcpy(data, source_data, mem_size); | |||||
| } | |||||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
| new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
| return new_vnode; | |||||
| } | |||||
| private: | |||||
| bool is_zero_{false}, is_one_{false}; | |||||
| ValuePtr zero_; | |||||
| AnfNodePtr x_{nullptr}; | |||||
| }; | |||||
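A sketch of the rewrite TensorMultiplyByZeroOrOne performs, built with the same graph helpers this patch already uses (add_parameter, NewCNode, NewValueNode); the optimizer argument and the abstracts on the nodes are assumed to be set up by the pipeline.

auto fg = std::make_shared<FuncGraph>();
auto x = fg->add_parameter();
auto zeros = fg->NewCNode({NewValueNode(prim::kPrimZerosLike), x});
auto mul = fg->NewCNode({NewValueNode(prim::kPrimMul), zeros, x});
// With mul->abstract() holding an AbstractTensor, the visitor returns a constant tensor of
// zeros with mul's shape, removing the multiplication; multiplying by an all-ones constant
// instead returns x (or a same-shaped copy when x is a single-element constant tensor).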
| // {prim::kPrimScalarAdd, X, 0} | // {prim::kPrimScalarAdd, X, 0} | ||||
| // {prim::kPrimScalarAdd, 0, X} | // {prim::kPrimScalarAdd, 0, X} | ||||
| class AddByZero : public AnfVisitor { | class AddByZero : public AnfVisitor { | ||||
| @@ -101,7 +311,8 @@ class AddByZero : public AnfVisitor { | |||||
| } | } | ||||
| void Visit(const AnfNodePtr &node) override { | void Visit(const AnfNodePtr &node) override { | ||||
| if (node->isa<ValueNode>() && *GetValueNode(node) == *zero_) { | |||||
| if (node->isa<ValueNode>() && | |||||
| ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { | |||||
| is_zero_ = true; | is_zero_ = true; | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -139,10 +350,22 @@ class TensorAddByZero : public AnfVisitor { | |||||
| is_zero_ = true; | is_zero_ = true; | ||||
| return; | return; | ||||
| } | } | ||||
| if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = node; | x_ = node; | ||||
| } | } | ||||
| void Visit(const ValueNodePtr &vnode) override { | |||||
| auto value = vnode->value(); | |||||
| if (CheckTensorConstant(0).IsTensorConstant(value)) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| } | |||||
| void Reset() { | void Reset() { | ||||
| x_ = nullptr; | x_ = nullptr; | ||||
| is_zero_ = false; | is_zero_ = false; | ||||
| @@ -183,29 +406,143 @@ class OptUpdateZeroTensor : public AnfVisitor { | |||||
| // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} | // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} | ||||
| class ConstantDuplicateMul : public AnfVisitor { | class ConstantDuplicateMul : public AnfVisitor { | ||||
| public: | public: | ||||
| // Helper that multiplies two constant tensors element-wise; broadcasting is limited to single-element operands | |||||
| template <typename T> | |||||
| void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, | |||||
| int out_data_size) { | |||||
| T *data_1 = reinterpret_cast<T *>(in_data_1); | |||||
| T *data_2 = reinterpret_cast<T *>(in_data_2); | |||||
| T *data_out = new T[out_data_size]; | |||||
| if (in_data_1_size == 1) { | |||||
| for (int i = 0; i < out_data_size; i++) { | |||||
| data_out[i] = data_1[0]; | |||||
| } | |||||
| } else { | |||||
| for (int i = 0; i < out_data_size; i++) { | |||||
| data_out[i] = data_1[i]; | |||||
| } | |||||
| } | |||||
| if (in_data_2_size == 1) { | |||||
| for (int i = 0; i < out_data_size; i++) { | |||||
| data_out[i] *= data_2[0]; | |||||
| } | |||||
| } else { | |||||
| for (int i = 0; i < out_data_size; i++) { | |||||
| data_out[i] *= data_2[i]; | |||||
| } | |||||
| } | |||||
| *out_data = reinterpret_cast<void *>(data_out); | |||||
| return; | |||||
| } | |||||
| AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) { | |||||
| if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) || | |||||
| (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto value_1 = GetValueNode(vnode_1); | |||||
| auto value_2 = GetValueNode(vnode_2); | |||||
| if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1); | |||||
| auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2); | |||||
| auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); | |||||
| TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); | |||||
| TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); | |||||
| if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || | |||||
| (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape(); | |||||
| int data_out_size = 1; | |||||
| for (auto it : tensor_out_shape) { | |||||
| data_out_size *= it; | |||||
| } | |||||
| if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { | |||||
| return nullptr; | |||||
| } | |||||
| if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { | |||||
| return nullptr; | |||||
| } | |||||
| void *data_out; | |||||
| if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || | |||||
| (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { | |||||
| Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||||
| tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||||
| } else { | |||||
| if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { | |||||
| Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||||
| tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||||
| } else { | |||||
| if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || | |||||
| (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { | |||||
| Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||||
| tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||||
| } else { | |||||
| // Unsupported data types | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| } | |||||
| auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape); | |||||
| size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true)); | |||||
| memcpy(data, data_out, mem_size); | |||||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
| new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
| return new_vnode; | |||||
| } | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| Reset(); | Reset(); | ||||
| // {prim::kPrimMul, Tensor1, {...}} | // {prim::kPrimMul, Tensor1, {...}} | ||||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); | AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); | ||||
| if (vnode_ == nullptr || cnode_ == nullptr) { | |||||
| if (vnode_ == nullptr || c_p_node_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!IsCNode(c_p_node_)) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto tensor1 = vnode_; | auto tensor1 = vnode_; | ||||
| auto mul = cnode_; | |||||
| auto mul = c_p_node_->cast<CNodePtr>(); | |||||
| Reset(); | Reset(); | ||||
| // {prim::kPrimMul, Tensor2, {...}} | // {prim::kPrimMul, Tensor2, {...}} | ||||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); | AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); | ||||
| if (vnode_ == nullptr || cnode_ == nullptr) { | |||||
| if (vnode_ == nullptr || c_p_node_ == nullptr) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto tensor2 = vnode_; | auto tensor2 = vnode_; | ||||
| auto cnode = cnode_; | |||||
| auto c_p_node = c_p_node_; | |||||
| auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0)); | auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0)); | ||||
| auto fg = node->func_graph(); | auto fg = node->func_graph(); | ||||
| auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); | |||||
| return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg); | |||||
| auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); | |||||
| if (new_mul_tensor == nullptr) { | |||||
| auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); | |||||
| return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); | |||||
| } | |||||
| return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); | |||||
| } | } | ||||
| void Visit(const AnfNodePtr &node) override { | void Visit(const AnfNodePtr &node) override { | ||||
| @@ -213,19 +550,40 @@ class ConstantDuplicateMul : public AnfVisitor { | |||||
| vnode_ = node; | vnode_ = node; | ||||
| } | } | ||||
| if (IsCNode(node)) { | |||||
| cnode_ = node->cast<CNodePtr>(); | |||||
| if (IsCNode(node) || IsParam(node)) { | |||||
| c_p_node_ = node; | |||||
| } | } | ||||
| } | } | ||||
| void Reset() { | void Reset() { | ||||
| vnode_ = nullptr; | vnode_ = nullptr; | ||||
| cnode_ = nullptr; | |||||
| c_p_node_ = nullptr; | |||||
| } | } | ||||
| private: | private: | ||||
| AnfNodePtr vnode_; | AnfNodePtr vnode_; | ||||
| CNodePtr cnode_; | |||||
| AnfNodePtr c_p_node_; | |||||
| }; | |||||
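A before/after sketch of the pattern the extended ConstantDuplicateMul now folds at compile time; tensor1 and tensor2 stand for constant tensor::TensorPtr values and are assumptions for illustration.

auto fg = std::make_shared<FuncGraph>();
auto x = fg->add_parameter();
auto inner = fg->NewCNode({NewValueNode(prim::kPrimMul), NewValueNode(tensor1), x});
auto outer = fg->NewCNode({NewValueNode(prim::kPrimMul), NewValueNode(tensor2), inner});
// When both constants share outer's dtype and a compatible size, MulConstantTensors computes
// tensor1 * tensor2 once and the visitor returns {Mul, x, folded_constant}; otherwise it falls
// back to the old behaviour and only regroups the constants into a separate {Mul, T1, T2} node.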
| class PowerOneEliminate : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| if (!IsValueNode<Scalar>(inputs[2])) { | |||||
| return nullptr; | |||||
| } | |||||
| auto scalar = GetValueNode<ScalarPtr>(inputs[2]); | |||||
| if (scalar->isa<FloatImm>() && GetValue<float>(scalar) == 1.0) { | |||||
| return inputs[1]; | |||||
| } else if (scalar->isa<IntergerImm>() && GetValue<int>(scalar) == 1) { | |||||
| return inputs[1]; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| }; | }; | ||||
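A minimal sketch of the new PowerOneEliminate rule, again with the graph helpers used above; the float scalar value node is an assumed construction.

auto fg = std::make_shared<FuncGraph>();
auto x = fg->add_parameter();
auto pow_one = fg->NewCNode({NewValueNode(prim::kPrimPow), x, NewValueNode(MakeValue(1.0f))});
// Pow with a scalar exponent equal to 1 (float or integer immediate) is an identity, so
// PowerOneEliminate()(optimizer, pow_one) simply returns x and the Pow node is dropped.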
| // grad = AllReduce(grad) / worker_number | // grad = AllReduce(grad) / worker_number | ||||
| @@ -341,17 +699,21 @@ class ArithmeticSimplify { | |||||
| public: | public: | ||||
| ArithmeticSimplify() | ArithmeticSimplify() | ||||
| : multiply_by_zero_or_one_(), | : multiply_by_zero_or_one_(), | ||||
| tensor_multiply_by_zero_or_one_(), | |||||
| add_by_zero_(), | add_by_zero_(), | ||||
| tensor_add_by_zero_(), | tensor_add_by_zero_(), | ||||
| identity_(prim::kPrimIdentity), | identity_(prim::kPrimIdentity), | ||||
| opt_update_zero_tensor_(), | opt_update_zero_tensor_(), | ||||
| constant_duplicate_mul_() { | |||||
| constant_duplicate_mul_(), | |||||
| power_one_() { | |||||
| eliminaters_.emplace_back(multiply_by_zero_or_one_); | eliminaters_.emplace_back(multiply_by_zero_or_one_); | ||||
| eliminaters_.emplace_back(tensor_multiply_by_zero_or_one_); | |||||
| eliminaters_.emplace_back(add_by_zero_); | eliminaters_.emplace_back(add_by_zero_); | ||||
| eliminaters_.emplace_back(tensor_add_by_zero_); | eliminaters_.emplace_back(tensor_add_by_zero_); | ||||
| eliminaters_.emplace_back(identity_); | eliminaters_.emplace_back(identity_); | ||||
| eliminaters_.emplace_back(opt_update_zero_tensor_); | eliminaters_.emplace_back(opt_update_zero_tensor_); | ||||
| eliminaters_.emplace_back(constant_duplicate_mul_); | eliminaters_.emplace_back(constant_duplicate_mul_); | ||||
| eliminaters_.emplace_back(power_one_); | |||||
| } | } | ||||
| ~ArithmeticSimplify() = default; | ~ArithmeticSimplify() = default; | ||||
| @@ -368,11 +730,13 @@ class ArithmeticSimplify { | |||||
| private: | private: | ||||
| MultiplyByZeroOrOne multiply_by_zero_or_one_; | MultiplyByZeroOrOne multiply_by_zero_or_one_; | ||||
| TensorMultiplyByZeroOrOne tensor_multiply_by_zero_or_one_; | |||||
| AddByZero add_by_zero_; | AddByZero add_by_zero_; | ||||
| TensorAddByZero tensor_add_by_zero_; | TensorAddByZero tensor_add_by_zero_; | ||||
| PrimEliminater identity_; | PrimEliminater identity_; | ||||
| OptUpdateZeroTensor opt_update_zero_tensor_; | OptUpdateZeroTensor opt_update_zero_tensor_; | ||||
| ConstantDuplicateMul constant_duplicate_mul_; | ConstantDuplicateMul constant_duplicate_mul_; | ||||
| PowerOneEliminate power_one_; | |||||
| std::vector<TransformFuncType> eliminaters_{}; | std::vector<TransformFuncType> eliminaters_{}; | ||||
| }; | }; | ||||
| } // namespace irpass | } // namespace irpass | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_set> | |||||
| #include "optimizer/irpass.h" | #include "optimizer/irpass.h" | ||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| @@ -28,7 +29,6 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| @@ -81,13 +81,32 @@ class IncorporateGetitem : public AnfVisitor { | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| Reset(); | Reset(); | ||||
| AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node); | AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node); | ||||
| if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| if (node->func_graph() != nullptr && idx_ >= 0 && fg_ != nullptr) { | |||||
| auto new_fg = getitem_transform_(fg_, idx_); | |||||
| (void)args_.insert(args_.begin(), NewValueNode(new_fg)); | |||||
| return node->func_graph()->NewCNode(args_); | |||||
| if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| // If the graph kernel has more than one real output, do not split it. | |||||
| // Outputs that are only EnvInstance or DeadCode nodes do not count as real outputs, so such graphs can still be split. | |||||
| auto output = fg_->output(); | |||||
| if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { | |||||
| auto output_cnode = output->cast<CNodePtr>(); | |||||
| auto outputs = output_cnode->inputs(); | |||||
| int real_output_cnt = 0; | |||||
| for (size_t i = 1; i < outputs.size(); ++i) { | |||||
| if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) { | |||||
| real_output_cnt++; | |||||
| if (real_output_cnt > 1) { | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| return nullptr; | |||||
| auto new_fg = getitem_transform_(fg_, idx_); | |||||
| (void)args_.insert(args_.begin(), NewValueNode(new_fg)); | |||||
| return node->func_graph()->NewCNode(args_); | |||||
| } | } | ||||
| void Visit(const CNodePtr &cnode) override { | void Visit(const CNodePtr &cnode) override { | ||||
| @@ -115,6 +134,172 @@ class IncorporateGetitem : public AnfVisitor { | |||||
| internal::GetitemTransform getitem_transform_; | internal::GetitemTransform getitem_transform_; | ||||
| }; | }; | ||||
| class IncorporateGetitemFromParam : public AnfVisitor { | |||||
| public: | |||||
| void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr ¶m, size_t input_idx) { | |||||
| auto mng = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto &node_users = mng->node_users(); | |||||
| if (node_users.find(param) == node_users.end() || node_users[param].empty()) { | |||||
| args_.push_back(cnode->input(input_idx + 1)); | |||||
| return; | |||||
| } | |||||
| for (auto &user : node_users[param]) { | |||||
| if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { | |||||
| // Only TupleGetItem users are handled; keep the original argument unchanged. | |||||
| args_.push_back(cnode->input(input_idx + 1)); | |||||
| return; | |||||
| } | |||||
| } | |||||
| // update new args. | |||||
| if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) { | |||||
| // case 1 | |||||
| replace_parameters_[input_idx] = true; | |||||
| need_update_ = true; | |||||
| auto make_tuple_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>(); | |||||
| auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs(); | |||||
| inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1; | |||||
| args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end()); | |||||
| } else { | |||||
| // case 2 | |||||
| auto prev_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>(); | |||||
| auto prev_fg = GetValueNode<FuncGraphPtr>(prev_cnode->input(0)); | |||||
| auto fg_output = prev_fg->output(); | |||||
| if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) { | |||||
| MS_LOG(ERROR) << "The return of: " << prev_fg->ToString() | |||||
| << " should be a make tuple, but got: " << fg_output->DebugString(); | |||||
| return; | |||||
| } | |||||
| replace_parameters_[input_idx] = true; | |||||
| need_update_ = true; | |||||
| auto make_tuple_cnode = fg_output->cast<CNodePtr>(); | |||||
| inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1; | |||||
| for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) { | |||||
| auto new_getitem = | |||||
| func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToInt(output_i))}); | |||||
| auto aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(SizeToInt(output_i))); | |||||
| new_getitem->input(2)->set_abstract(aptr); | |||||
| new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract()); | |||||
| args_.push_back(new_getitem); | |||||
| } | |||||
| } | |||||
| } | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| if (node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| Reset(); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto &inputs = cnode->inputs(); | |||||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||||
| if (fg == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto mng = fg->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto parameters = fg->parameters(); | |||||
| if (parameters.size() != inputs.size() - 1) { | |||||
| return nullptr; | |||||
| } | |||||
| replace_parameters_ = std::vector<bool>(parameters.size(), false); | |||||
| inputs_num_ = std::vector<size_t>(parameters.size(), 1); | |||||
| auto node_fg = node->func_graph(); | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||||
| if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) { | |||||
| Process(node_fg, cnode, parameters[i - 1], i - 1); | |||||
| } else { | |||||
| args_.push_back(inputs[i]); | |||||
| } | |||||
| } | |||||
| if (!need_update_) { | |||||
| return nullptr; | |||||
| } | |||||
| FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp")); | |||||
| mng->AddFuncGraph(new_fg); | |||||
| auto node_users = mng->node_users(); | |||||
| std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters(); | |||||
| std::vector<AnfNodePtr> new_parameters; | |||||
| size_t curr_input_idx{0}; | |||||
| for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) { | |||||
| if (!replace_parameters_[param_i]) { | |||||
| if (parameters[param_i]->abstract() != nullptr) { | |||||
| new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract()); | |||||
| } | |||||
| new_parameters.push_back(new_fg_parameters[param_i]); | |||||
| curr_input_idx++; | |||||
| continue; | |||||
| } | |||||
| // make a new parameter. | |||||
| for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) { | |||||
| auto new_param = std::make_shared<Parameter>(new_fg); | |||||
| new_param->set_abstract(args_.at(curr_input_idx)->abstract()); | |||||
| // update users of new parameter. | |||||
| for (auto &user : node_users[new_fg_parameters[param_i]]) { | |||||
| idx_ = -1; | |||||
| AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode<Int32Imm>})(user.first); | |||||
| if (idx_ == -1) { | |||||
| MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString() | |||||
| << " must be tuple getitem here, but got: " << user.first->DebugString(); | |||||
| return nullptr; | |||||
| } | |||||
| if (input_i == IntToSize(idx_)) { | |||||
| for (auto &sub_user : node_users[user.first]) { | |||||
| auto sub_user_cnode = sub_user.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(sub_user_cnode); | |||||
| sub_user_cnode->set_input(sub_user.second, new_param); | |||||
| (void)mng->Replace(sub_user.first, sub_user_cnode); | |||||
| } | |||||
| } | |||||
| } | |||||
| // (void)mng->Replace(new_fg_parameters[param_i], new_param); | |||||
| new_parameters.push_back(new_param); | |||||
| curr_input_idx++; | |||||
| } | |||||
| } | |||||
| mng->SetParameters(new_fg, new_parameters); | |||||
| (void)args_.insert(args_.begin(), NewValueNode(new_fg)); | |||||
| auto new_call = node_fg->NewCNode(args_); | |||||
| new_call->set_abstract(node->abstract()); | |||||
| return new_call; | |||||
| } | |||||
| void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int>(vnode->value()); } | |||||
| void Visit(const CNodePtr &cnode) override {} | |||||
| void Reset() { | |||||
| replace_parameters_.clear(); | |||||
| args_.clear(); | |||||
| inputs_num_.clear(); | |||||
| need_update_ = false; | |||||
| idx_ = -1; | |||||
| } | |||||
| private: | |||||
| std::vector<bool> replace_parameters_{}; | |||||
| std::vector<AnfNodePtr> args_{}; | |||||
| std::vector<size_t> inputs_num_{}; | |||||
| bool need_update_{false}; | |||||
| int idx_{-1}; | |||||
| }; | |||||
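A sketch of the first case handled by IncorporateGetitemFromParam::Process: a MakeTuple passed straight into a graph-kernel call whose parameter is only read through TupleGetItem. Here sub_fg stands for an assumed graph-kernel FuncGraph carrying FUNC_GRAPH_ATTR_GRAPH_KERNEL.

auto main_fg = std::make_shared<FuncGraph>();
auto a = main_fg->add_parameter();
auto b = main_fg->add_parameter();
auto tuple = main_fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), a, b});
auto call = main_fg->NewCNode({NewValueNode(sub_fg), tuple});
// The pass clones sub_fg, replaces its single tuple parameter with two plain parameters,
// redirects the TupleGetItem users inside the clone to those parameters, and returns the
// rewritten call {clone, a, b}, so no tuple value has to cross the graph-kernel boundary.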
| // {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} | // {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} | ||||
| class IncorporateGetitemSwitch : public AnfVisitor { | class IncorporateGetitemSwitch : public AnfVisitor { | ||||
| public: | public: | ||||
| @@ -86,20 +86,10 @@ bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { | |||||
| bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { | bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | MS_EXCEPTION_IF_NULL(node->func_graph()); | ||||
| auto &flags = node->func_graph()->flags(); | |||||
| if (flags.find("inline_inside") != flags.end()) { | |||||
| return flags["inline_inside"]; | |||||
| } | |||||
| return false; | |||||
| return node->func_graph()->has_flag("inline_inside"); | |||||
| } | } | ||||
| bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { | |||||
| auto &flags = fg->flags(); | |||||
| if (flags.find("core") != flags.end()) { | |||||
| return flags["core"]; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } | |||||
| bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } | bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } | ||||
| @@ -123,6 +113,13 @@ class InlinerBase : public AnfVisitor { | |||||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { | if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // Do not inline a GraphKernel sub graph into an ordinary (Cell) graph. | |||||
| if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| // If the GraphKernel contains only a return node, inline it anyway. | |||||
| if (fg->nodes().size() - fg->parameters().size() > 1) { | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| Reset(); | Reset(); | ||||
| bool is_match = false; | bool is_match = false; | ||||
| @@ -0,0 +1,86 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H | |||||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H | |||||
| #include <string> | |||||
| #include <sstream> | |||||
| #include <iomanip> | |||||
| #include <unordered_map> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "optimizer/irpass.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "operator/ops.h" | |||||
| #include "utils/graph_utils.h" | |||||
| #include "operator/composite/composite.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace irpass { | |||||
| static int count = 0; | |||||
| std::string GetFusionNumber() { | |||||
| std::stringstream ss; | |||||
| ss << std::setw(4) << std::setfill('0') << count; | |||||
| std::string num = ss.str(); | |||||
| ++count; | |||||
| return "_" + num; | |||||
| } | |||||
| // Mark CNodes which can be merged in kernel build | |||||
| class MarkInterfaceFusion : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto condition = cnode->input(1); | |||||
| std::string cmp; | |||||
| std::unordered_map<std::string, std::string> cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"}, | |||||
| {"LessEqual", "LE"}, {"Less", "LT"}, | |||||
| {"Equal", "EQ"}, {"NotEqual", "NE"}}; | |||||
| if (IsPrimitiveCNode(condition)) { | |||||
| auto prim_name = GetCNodeFuncName(condition->cast<CNodePtr>()); | |||||
| if (cmp_list.count(prim_name) != 0) { | |||||
| // Mark Select and compare node | |||||
| cmp = cmp_list[prim_name]; | |||||
| auto cnt = GetFusionNumber(); | |||||
| AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition); | |||||
| AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node); | |||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||||
| if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) { | |||||
| AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void Visit(const AnfNodePtr &) override {} | |||||
| private: | |||||
| AnfNodePtr y_{nullptr}; | |||||
| }; | |||||
| } // namespace irpass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H | |||||
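The pass above tags a Select node, its comparison input, and any ZerosLike operands with a shared "fusion" attribute built from a comparison abbreviation plus a zero-padded counter, so the kernel builder can merge them later. The snippet below is a minimal standalone sketch of that naming scheme, grounded in the GetFusionNumber and cmp_list code shown above; the free function and global counter names here are stand-ins, not MindSpore code.

```cpp
// Minimal sketch of the fusion-name scheme used by MarkInterfaceFusion.
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>

namespace {
int g_fusion_count = 0;  // hypothetical stand-in for the file-local `count`

std::string NextFusionSuffix() {
  std::stringstream ss;
  ss << std::setw(4) << std::setfill('0') << g_fusion_count++;
  return "_" + ss.str();  // "_0000", "_0001", ...
}
}  // namespace

int main() {
  // Abbreviations mirror the cmp_list table in MarkInterfaceFusion.
  std::unordered_map<std::string, std::string> cmp_list = {
      {"GreaterEqual", "GE"}, {"Greater", "GT"}, {"LessEqual", "LE"},
      {"Less", "LT"},         {"Equal", "EQ"},   {"NotEqual", "NE"}};

  // For a Select whose condition is Greater(x, y), the pass attaches
  // "SelectGT_0000" to the compare node and "SelectGT_0000_end" to the Select.
  const std::string suffix = NextFusionSuffix();
  std::cout << "compare: " << "Select" + cmp_list["Greater"] + suffix << "\n";
  std::cout << "select : " << "Select" + cmp_list["Greater"] + suffix + "_end" << "\n";
  return 0;
}
```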
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | |||||
| #include "optimizer/irpass.h" | #include "optimizer/irpass.h" | ||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| @@ -196,6 +197,131 @@ class AddNZeroFilter : public AnfVisitor { | |||||
| std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{}; | std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{}; | ||||
| bool has_zero_like_{false}; | bool has_zero_like_{false}; | ||||
| }; | }; | ||||
| // {PrimAddN, {kPrimMakeTuple, Xs}} | |||||
| // Akg doesn't support AddN(ValueNode, Tensor, ...), so it is converted to TensorAdd. | |||||
| // case0: AddN(inputs)(inputs size < 2) -> error | |||||
| // case1: AddN(inputs)(all inputs are ValueNode) -> error | |||||
| // case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor) | |||||
| // case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input) | |||||
| // -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) | |||||
| class AddNEliminater : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| auto mng = fg->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| if (fg->recursive()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg")); | |||||
| mng->AddFuncGraph(new_fg); | |||||
| need_update_ = false; | |||||
| bool changed = false; | |||||
| do { | |||||
| changed = false; | |||||
| changed |= Process(new_fg); | |||||
| } while (changed); | |||||
| if (!need_update_) { | |||||
| return nullptr; | |||||
| } else { | |||||
| auto new_sx = inputs; | |||||
| new_sx[0] = NewValueNode(new_fg); | |||||
| return node->func_graph()->NewCNode(new_sx); | |||||
| } | |||||
| } | |||||
| bool Process(const FuncGraphPtr &func_graph) { | |||||
| auto mng = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto nodes = TopoSort(func_graph->output()); | |||||
| bool changed = false; | |||||
| for (size_t i = 0; i < nodes.size(); ++i) { | |||||
| auto node = nodes[i]; | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto &tuple_input = cnode->input(1); | |||||
| MS_EXCEPTION_IF_NULL(tuple_input); | |||||
| auto tuple_input_cnode = tuple_input->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tuple_input_cnode); | |||||
| auto &tuple_inputs = tuple_input_cnode->inputs(); | |||||
| if (tuple_inputs.size() < 3) { | |||||
| // case0: inputs size < 2, error | |||||
| MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2); | |||||
| } | |||||
| int valuenode_num = | |||||
| std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) { | |||||
| if (IsValueNode<tensor::Tensor>(node)) { | |||||
| return accumulator + 1; | |||||
| } else { | |||||
| return accumulator; | |||||
| } | |||||
| }); | |||||
| if (IntToSize(valuenode_num) == tuple_inputs.size()) { | |||||
| // case1: all inputs are ValueNode, error | |||||
| MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2); | |||||
| } | |||||
| if (tuple_inputs.size() == 3) { | |||||
| // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor) | |||||
| MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2); | |||||
| ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); | |||||
| std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1], | |||||
| tuple_inputs[2]}; | |||||
| mng->Replace(node, func_graph->NewCNode(new_xs)); | |||||
| changed = true; | |||||
| continue; | |||||
| } | |||||
| auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(), | |||||
| [](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); }); | |||||
| if (first_valuenode == tuple_inputs.end()) { | |||||
| // no ValueNode input found. | |||||
| continue; | |||||
| } else { | |||||
| // case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) | |||||
| std::vector<AnfNodePtr> make_tuple_new_xs{ | |||||
| NewValueNode(prim::kPrimMakeTuple), | |||||
| }; | |||||
| std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(), | |||||
| [&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) { | |||||
| if (node != *first_valuenode) { | |||||
| make_tuple_new_xs.push_back(node); | |||||
| } | |||||
| }); | |||||
| ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations"); | |||||
| auto new_addn = func_graph->NewCNode( | |||||
| {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)}); | |||||
| ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); | |||||
| auto new_add = | |||||
| func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn}); | |||||
| (void)mng->Replace(node, new_add); | |||||
| changed = true; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| need_update_ |= changed; | |||||
| return changed; | |||||
| } | |||||
| private: | |||||
| bool need_update_{false}; | |||||
| }; | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
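AddNEliminater's case 3 counts the constant (ValueNode) operands with std::accumulate, then uses std::find_if to peel the first one off before rebuilding the call as TensorAdd(ValueNode, AddN(rest)). Below is a self-contained sketch of that count-then-split step over plain values instead of AnfNodePtr; the struct and variable names are hypothetical, purely for illustration.

```cpp
// Illustrative sketch of the count-then-split step in AddNEliminater.
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

struct Operand {
  bool is_const;  // stands in for IsValueNode<tensor::Tensor>(node)
  double value;
};

int main() {
  std::vector<Operand> inputs = {{true, 1.0}, {false, 2.0}, {false, 3.0}, {false, 4.0}};

  // case1 guard: count how many operands are constants.
  int const_num = std::accumulate(
      inputs.begin(), inputs.end(), 0,
      [](int acc, const Operand &op) { return op.is_const ? acc + 1 : acc; });
  if (static_cast<size_t>(const_num) == inputs.size()) {
    std::cerr << "all inputs are constants\n";
    return 1;
  }

  // case3: split off the first constant, AddN the rest, then add the constant back.
  auto first_const = std::find_if(inputs.begin(), inputs.end(),
                                  [](const Operand &op) { return op.is_const; });
  double addn_rest = 0.0;
  for (auto it = inputs.begin(); it != inputs.end(); ++it) {
    if (it != first_const) addn_rest += it->value;
  }
  double result = first_const->value + addn_rest;  // TensorAdd(ValueNode, AddN(rest))
  std::cout << "result = " << result << "\n";      // prints 10
  return 0;
}
```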
| @@ -79,7 +79,7 @@ class ReduceOneEliminater : public AnfVisitor { | |||||
| } | } | ||||
| void Visit(const AnfNodePtr &node) override { | void Visit(const AnfNodePtr &node) override { | ||||
| if (x_ == nullptr) { | |||||
| if (!IsVNode(node) && x_ == nullptr) { | |||||
| if (IsValueNode<tensor::Tensor>(node)) { | if (IsValueNode<tensor::Tensor>(node)) { | ||||
| is_tensor_ = true; | is_tensor_ = true; | ||||
| } | } | ||||
| @@ -23,6 +23,8 @@ | |||||
| #include "optimizer/irpass.h" | #include "optimizer/irpass.h" | ||||
| #include "ir/visitor.h" | #include "ir/visitor.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "utils/graph_utils.h" | |||||
| #include "operator/composite/composite.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -36,6 +38,7 @@ class MakeRefEliminater : public AnfVisitor { | |||||
| this->y_ = node; | this->y_ = node; | ||||
| return true; | return true; | ||||
| }; | }; | ||||
| AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node); | AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node); | ||||
| return y_; | return y_; | ||||
| } | } | ||||
| @@ -142,7 +142,7 @@ class ResetDeferInline : public AnfVisitor { | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| if (IsValueNode<FuncGraph>(node)) { | if (IsValueNode<FuncGraph>(node)) { | ||||
| auto fg = GetValueNode<FuncGraphPtr>(node); | auto fg = GetValueNode<FuncGraphPtr>(node); | ||||
| fg->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); | |||||
| fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | #include <utility> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <unordered_set> | |||||
| #include "optimizer/irpass.h" | #include "optimizer/irpass.h" | ||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| @@ -41,7 +42,7 @@ class SpecializeTransform { | |||||
| ~SpecializeTransform() = default; | ~SpecializeTransform() = default; | ||||
| FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args, | FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args, | ||||
| std::vector<PrimitivePtr> prim_args) { | |||||
| std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) { | |||||
| if (cache_.count(func_graph) == 0) { | if (cache_.count(func_graph) == 0) { | ||||
| cache_[func_graph] = {}; | cache_[func_graph] = {}; | ||||
| } | } | ||||
| @@ -69,6 +70,13 @@ class SpecializeTransform { | |||||
| (void)mng->Replace(params[i], arg); | (void)mng->Replace(params[i], arg); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (value_args[i] != nullptr) { | |||||
| auto const_tensor = *value_args[i]; | |||||
| auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor); | |||||
| AnfNodePtr arg = NewValueNode(const_tensor_ptr); | |||||
| (void)mng->Replace(params[i], arg); | |||||
| continue; | |||||
| } | |||||
| new_params.push_back(params[i]); | new_params.push_back(params[i]); | ||||
| } | } | ||||
| @@ -108,6 +116,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { | |||||
| std::vector<FuncGraphPtr> graph_args; | std::vector<FuncGraphPtr> graph_args; | ||||
| std::vector<PrimitivePtr> prim_args; | std::vector<PrimitivePtr> prim_args; | ||||
| std::vector<tensor::TensorPtr> value_node_args; | |||||
| std::vector<AnfNodePtr> new_xs; | std::vector<AnfNodePtr> new_xs; | ||||
| bool hasVNode = false; | bool hasVNode = false; | ||||
| for (size_t i = 1; i < inputs.size(); i++) { | for (size_t i = 1; i < inputs.size(); i++) { | ||||
| @@ -115,15 +124,24 @@ class SpecializeOnGraphArguments : public AnfVisitor { | |||||
| auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]); | auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]); | ||||
| graph_args.push_back(fg_vnode); | graph_args.push_back(fg_vnode); | ||||
| prim_args.emplace_back(nullptr); | prim_args.emplace_back(nullptr); | ||||
| value_node_args.emplace_back(nullptr); | |||||
| hasVNode = true; | hasVNode = true; | ||||
| } else if (IsValueNode<Primitive>(inputs[i])) { | } else if (IsValueNode<Primitive>(inputs[i])) { | ||||
| auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]); | auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]); | ||||
| graph_args.emplace_back(nullptr); | graph_args.emplace_back(nullptr); | ||||
| prim_args.push_back(p_vnode); | prim_args.push_back(p_vnode); | ||||
| value_node_args.emplace_back(nullptr); | |||||
| hasVNode = true; | |||||
| } else if (IsValueNode<tensor::Tensor>(inputs[i])) { | |||||
| tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]); | |||||
| graph_args.emplace_back(nullptr); | |||||
| prim_args.emplace_back(nullptr); | |||||
| value_node_args.emplace_back(t_vnode); | |||||
| hasVNode = true; | hasVNode = true; | ||||
| } else { | } else { | ||||
| graph_args.emplace_back(nullptr); | graph_args.emplace_back(nullptr); | ||||
| prim_args.emplace_back(nullptr); | prim_args.emplace_back(nullptr); | ||||
| value_node_args.emplace_back(nullptr); | |||||
| new_xs.push_back(inputs[i]); | new_xs.push_back(inputs[i]); | ||||
| } | } | ||||
| } | } | ||||
| @@ -132,7 +150,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args); | |||||
| auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args); | |||||
| (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); | (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); | ||||
| return node->func_graph()->NewCNode(new_xs); | return node->func_graph()->NewCNode(new_xs); | ||||
| @@ -141,6 +159,146 @@ class SpecializeOnGraphArguments : public AnfVisitor { | |||||
| private: | private: | ||||
| internal::SpecializeTransform specialize_transform_; | internal::SpecializeTransform specialize_transform_; | ||||
| }; | }; | ||||
| // Eliminate unused parameters. | |||||
| // {G, Xs} | |||||
| class UnusedParasEliminater : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto &inputs = cnode->inputs(); | |||||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| std::vector<AnfNodePtr> parameters = fg->parameters(); | |||||
| size_t size = parameters.size(); | |||||
| if (size != inputs.size() - 1) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> new_xs; | |||||
| std::vector<bool> keep_parameters; | |||||
| auto mng = fg->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| auto &node_users = mng->node_users(); | |||||
| bool has_unused_para = false; | |||||
| for (size_t i = 0; i < size; ++i) { | |||||
| auto iter = node_users.find(parameters[i]); | |||||
| if (iter != node_users.end() && !iter->second.empty()) { | |||||
| keep_parameters.push_back(true); | |||||
| new_xs.push_back(inputs[i + 1]); | |||||
| continue; | |||||
| } | |||||
| keep_parameters.push_back(false); | |||||
| has_unused_para = true; | |||||
| } | |||||
| if (!has_unused_para) { | |||||
| return nullptr; | |||||
| } | |||||
| FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp")); | |||||
| mng->AddFuncGraph(new_fg); | |||||
| std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters(); | |||||
| std::vector<AnfNodePtr> new_parameters; | |||||
| for (size_t i = 0; i < size; i++) { | |||||
| if (keep_parameters[i]) { | |||||
| if (parameters[i]->abstract() != nullptr) { | |||||
| new_fg_parameters[i]->set_abstract(parameters[i]->abstract()); | |||||
| } | |||||
| new_parameters.push_back(new_fg_parameters[i]); | |||||
| } | |||||
| } | |||||
| mng->SetParameters(new_fg, new_parameters); | |||||
| (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); | |||||
| return node->func_graph()->NewCNode(new_xs); | |||||
| } | |||||
| }; | |||||
| // Eliminate unused outputs. | |||||
| // {G, Xs} | |||||
| class UnusedOutputEliminater : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| auto mng = fg->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| if (fg->recursive()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg")); | |||||
| mng->AddFuncGraph(new_fg); | |||||
| auto new_fg_output = new_fg->output(); | |||||
| if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto output_cnode = new_fg_output->cast<CNodePtr>(); | |||||
| auto &node_users = mng->node_users(); | |||||
| if (node_users.count(node) == 0 || node_users[node].empty()) { | |||||
| return nullptr; | |||||
| } | |||||
| std::unordered_set<int> used_output_idx; | |||||
| std::vector<std::pair<AnfNodePtr, int>> all_users; | |||||
| for (auto &node_user : node_users[node]) { | |||||
| if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto user_cnode = node_user.first->cast<CNodePtr>(); | |||||
| size_t used_idx = GetValue<int>(user_cnode->input(2)->cast<ValueNodePtr>()->value()); | |||||
| used_output_idx.insert(used_idx); | |||||
| all_users.push_back(std::make_pair(node_user.first, used_idx)); | |||||
| } | |||||
| if (used_output_idx.size() >= output_cnode->inputs().size() - 1) { | |||||
| // all outputs have users. | |||||
| return nullptr; | |||||
| } | |||||
| if (used_output_idx.empty()) { | |||||
| // we do not process this case. | |||||
| return nullptr; | |||||
| } else if (used_output_idx.size() == 1) { | |||||
| // after elimination, only one output is left. | |||||
| new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1)); | |||||
| // update users. | |||||
| for (auto &ret_user : all_users) { | |||||
| (void)mng->Replace(ret_user.first, node); | |||||
| } | |||||
| } else { | |||||
| // after elimination, create a new multi-output tuple. | |||||
| std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)}; | |||||
| std::unordered_map<int, int> new_idx_map; | |||||
| for (auto idx : used_output_idx) { | |||||
| new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1); | |||||
| new_output_inputs.push_back(output_cnode->input(idx + 1)); | |||||
| } | |||||
| new_fg->set_output(new_fg->NewCNode(new_output_inputs)); | |||||
| // update users. | |||||
| for (auto &ret_user : all_users) { | |||||
| auto ret_user_cnode = ret_user.first->cast<CNodePtr>(); | |||||
| ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second])); | |||||
| } | |||||
| } | |||||
| auto new_sx = inputs; | |||||
| new_sx[0] = NewValueNode(new_fg); | |||||
| return node->func_graph()->NewCNode(new_sx); | |||||
| } | |||||
| }; | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
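UnusedOutputEliminater rebuilds the MakeTuple output from only the used indices and then rewrites each TupleGetItem user to point at the new position via new_idx_map. The following is a small standalone sketch of just that index remapping, grounded in the loop shown above; container names are hypothetical and the MakeTuple primitive slot is omitted for brevity.

```cpp
// Standalone sketch of the tuple-output index remapping in UnusedOutputEliminater.
#include <iostream>
#include <set>
#include <unordered_map>
#include <vector>

int main() {
  // Indices actually consumed by TupleGetItem users (used_output_idx).
  std::set<int> used_output_idx = {0, 3, 5};

  // Old index -> position inside the rebuilt MakeTuple (new_idx_map).
  std::unordered_map<int, int> new_idx_map;
  std::vector<int> new_output_inputs;  // stands in for the new MakeTuple operands
  for (int idx : used_output_idx) {
    new_idx_map[idx] = static_cast<int>(new_output_inputs.size());
    new_output_inputs.push_back(idx);  // real code pushes output_cnode->input(idx + 1)
  }

  // Each user's index ValueNode is rewritten accordingly.
  for (int old_idx : used_output_idx) {
    std::cout << "TupleGetItem index " << old_idx << " -> " << new_idx_map[old_idx] << "\n";
  }
  // Prints 0 -> 0, 3 -> 1, 5 -> 2.
  return 0;
}
```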
| @@ -89,7 +89,7 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; | |||||
| class Optimizer : public std::enable_shared_from_this<Optimizer> { | class Optimizer : public std::enable_shared_from_this<Optimizer> { | ||||
| public: | public: | ||||
| Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) | Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) | ||||
| : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {} | |||||
| : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {} | |||||
| virtual ~Optimizer() = default; | virtual ~Optimizer() = default; | ||||
| void Init(const OptPassGroupMap &passes, bool run_only_once) { | void Init(const OptPassGroupMap &passes, bool run_only_once) { | ||||
| @@ -132,6 +132,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| } | } | ||||
| FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { | FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { | ||||
| if (!is_enable_) { | |||||
| return func_graph; | |||||
| } | |||||
| // Optimizer step counter; | // Optimizer step counter; | ||||
| int counter = -1; | int counter = -1; | ||||
| bool changes = true; | bool changes = true; | ||||
| @@ -171,7 +174,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| }; | }; | ||||
| use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); | use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); | ||||
| if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) { | if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) { | ||||
| MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||||
| MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||||
| auto fg_name = | auto fg_name = | ||||
| "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | ||||
| func_graph->DumpFuncGraph(fg_name); | func_graph->DumpFuncGraph(fg_name); | ||||
| @@ -211,6 +214,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| void enable_watch_renormalize() { is_watch_renormalize_ = true; } | void enable_watch_renormalize() { is_watch_renormalize_ = true; } | ||||
| void disable_watch_renormalize() { is_watch_renormalize_ = false; } | void disable_watch_renormalize() { is_watch_renormalize_ = false; } | ||||
| bool is_watch_renormalize() { return is_watch_renormalize_; } | bool is_watch_renormalize() { return is_watch_renormalize_; } | ||||
| void set_enable(bool enable) { is_enable_ = enable; } | |||||
| private: | private: | ||||
| const std::string name_; | const std::string name_; | ||||
| @@ -220,6 +224,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| bool run_only_once_; | bool run_only_once_; | ||||
| std::vector<AnfNodePtr> untyped_nodes_; | std::vector<AnfNodePtr> untyped_nodes_; | ||||
| bool is_watch_renormalize_; | bool is_watch_renormalize_; | ||||
| bool is_enable_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
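The new is_enable_ member lets a pass group stay registered while becoming a no-op at runtime; InitOpt below uses it to switch the graph-kernel optimizers off when the context flag is not set. Here is a minimal sketch of that gate pattern with simplified stand-in types, not the real Optimizer class.

```cpp
// Minimal sketch of the enable gate added to Optimizer::step.
#include <iostream>
#include <string>

struct FuncGraph { std::string name; };

class Optimizer {
 public:
  explicit Optimizer(std::string name) : name_(std::move(name)) {}
  void set_enable(bool enable) { is_enable_ = enable; }

  FuncGraph step(FuncGraph fg) {
    if (!is_enable_) {
      return fg;  // disabled groups pass the graph through untouched
    }
    std::cout << name_ << " runs its passes\n";
    return fg;
  }

 private:
  std::string name_;
  bool is_enable_{true};  // mirrors the new is_enable_ member, default on
};

int main() {
  Optimizer gk("opt_graph_kernel_a");
  gk.set_enable(false);  // what InitOpt does when enable_graph_kernel is off
  FuncGraph g{"net"};
  g = gk.step(g);        // returns immediately
  return 0;
}
```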
| @@ -64,7 +64,7 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti | |||||
| DumpGraph(root, std::string(ALLREDUCE_FUSION_END)); | DumpGraph(root, std::string(ALLREDUCE_FUSION_END)); | ||||
| // allreduce fusion only run once | // allreduce fusion only run once | ||||
| root->flags()[ALLREDUCE_FUSION_RUN_ONCE_ONLY] = true; | |||||
| root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true); | |||||
| res->results()[pipeline::kStepParallelGraph] = root; | res->results()[pipeline::kStepParallelGraph] = root; | ||||
| #if defined(_WIN32) || defined(_WIN64) | #if defined(_WIN32) || defined(_WIN64) | ||||
| auto end_time = std::chrono::steady_clock::now(); | auto end_time = std::chrono::steady_clock::now(); | ||||
| @@ -158,8 +158,8 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(param_node); | MS_EXCEPTION_IF_NULL(param_node); | ||||
| MS_EXCEPTION_IF_NULL(ptr); | MS_EXCEPTION_IF_NULL(ptr); | ||||
| if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) || | |||||
| func_graph->flags()[TRAINING]) { | |||||
| if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) || | |||||
| func_graph->has_flag(TRAINING)) { | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -107,7 +107,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | |||||
| time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); | time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); | ||||
| MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us"; | MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us"; | ||||
| root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY] = true; | |||||
| root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true); | |||||
| return changes; | return changes; | ||||
| } | } | ||||
| @@ -2270,10 +2270,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||||
| (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) { | (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) { | ||||
| if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) { | if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) { | ||||
| if (HasStrategy(root)) { | if (HasStrategy(root)) { | ||||
| MS_LOG(INFO) << "strategies ignored in " << parallel_mode | |||||
| MS_LOG(INFO) << "Strategies ignored in " << parallel_mode | |||||
| << ", set_strategy() only valid in [semi_]auto_parallel."; | << ", set_strategy() only valid in [semi_]auto_parallel."; | ||||
| } | } | ||||
| root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; | |||||
| root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); | |||||
| } | } | ||||
| return changes; | return changes; | ||||
| @@ -2330,11 +2330,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||||
| DumpGraph(root, std::string(STEP_PARALLEL_END)); | DumpGraph(root, std::string(STEP_PARALLEL_END)); | ||||
| // step parallel only run once | // step parallel only run once | ||||
| root->flags()[SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY] = true; | |||||
| root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true); | |||||
| res->results()[pipeline::kStepParallelGraph] = root; | res->results()[pipeline::kStepParallelGraph] = root; | ||||
| // in auto parallel mode, no need to check if strategies are set | // in auto parallel mode, no need to check if strategies are set | ||||
| root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; | |||||
| root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); | |||||
| (void)gettimeofday(&end_time, nullptr); | (void)gettimeofday(&end_time, nullptr); | ||||
| uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); | uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); | ||||
| @@ -151,7 +151,10 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") | .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") | ||||
| .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") | .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") | ||||
| .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") | .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") | ||||
| .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print."); | |||||
| .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.") | |||||
| .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, | |||||
| "Set the GraphKernel switch to on or off.") | |||||
| .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch."); | |||||
| (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | ||||
| .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | ||||
| @@ -278,7 +278,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { | |||||
| if (bprop_graph != nullptr) { | if (bprop_graph != nullptr) { | ||||
| (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); | (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); | ||||
| (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); | (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); | ||||
| func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); | |||||
| func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true); | |||||
| } | } | ||||
| } | } | ||||
| *data = func_graph; | *data = func_graph; | ||||
| @@ -1448,15 +1448,23 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG); | py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG); | ||||
| for (auto &item : flags) { | for (auto &item : flags) { | ||||
| if (!py::isinstance<py::str>(item.first) || !py::isinstance<py::bool_>(item.second)) { | |||||
| if (!py::isinstance<py::str>(item.first)) { | |||||
| MS_LOG(ERROR) << "Type error in flags dict convert"; | MS_LOG(ERROR) << "Type error in flags dict convert"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto name = py::cast<std::string>(item.first); | auto name = py::cast<std::string>(item.first); | ||||
| auto value = py::cast<bool>(item.second); | |||||
| MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; | |||||
| func_graph->set_flags(name, value); | |||||
| if (py::isinstance<py::bool_>(item.second)) { | |||||
| auto value = py::cast<bool>(item.second); | |||||
| MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; | |||||
| func_graph->set_flag(name, value); | |||||
| } else if (py::isinstance<py::str>(item.second)) { | |||||
| auto value = py::cast<std::string>(item.second); | |||||
| MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; | |||||
| func_graph->set_attr(name, MakeValue(value)); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Type error in flags/attrs dict convert"; | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| return true; | return true; | ||||
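The change above makes UpdateFuncGraphFlags accept both bool entries (turned into graph flags) and string entries (turned into graph attrs) from the Python flag dict. The sketch below shows that dispatch with std::variant standing in for py::object; FakeGraph, its setters, and the example entries are hypothetical, not MindSpore API.

```cpp
// Sketch of the bool-vs-string dispatch in UpdateFuncGraphFlags.
#include <iostream>
#include <map>
#include <string>
#include <variant>

struct FakeGraph {
  void set_flag(const std::string &name, bool v) { std::cout << "flag " << name << " = " << v << "\n"; }
  void set_attr(const std::string &name, const std::string &v) { std::cout << "attr " << name << " = " << v << "\n"; }
};

bool UpdateFlags(FakeGraph *g, const std::map<std::string, std::variant<bool, std::string>> &flags) {
  for (const auto &item : flags) {
    if (std::holds_alternative<bool>(item.second)) {
      g->set_flag(item.first, std::get<bool>(item.second));         // bool -> graph flag
    } else if (std::holds_alternative<std::string>(item.second)) {
      g->set_attr(item.first, std::get<std::string>(item.second));  // string -> graph attr
    } else {
      return false;  // mirrors the "Type error in flags/attrs dict convert" path
    }
  }
  return true;
}

int main() {
  FakeGraph g;
  return UpdateFlags(&g, {{"defer_inline", true}, {"graph_kernel", std::string("Fused")}}) ? 0 : 1;
}
```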
| @@ -223,8 +223,8 @@ class Parser { | |||||
| FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse); | FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse); | ||||
| // In order to keep effect order in the sub-graphs which are generated by control flow. | // In order to keep effect order in the sub-graphs which are generated by control flow. | ||||
| // We copy the flags from the top graph to the sub-graphs. | // We copy the flags from the top graph to the sub-graphs. | ||||
| if (func_graph_ && !func_graph_->flags().empty()) { | |||||
| block->func_graph()->set_flags(func_graph_->flags()); | |||||
| if (func_graph_ && !func_graph_->attrs().empty()) { | |||||
| block->func_graph()->set_attrs(func_graph_->attrs()); | |||||
| } | } | ||||
| func_block_list_.push_back(block); | func_block_list_.push_back(block); | ||||
| return block; | return block; | ||||
| @@ -25,12 +25,14 @@ | |||||
| #include <functional> | #include <functional> | ||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "debug/anf_ir_utils.h" | |||||
| #include "pipeline/parse/parse_base.h" | #include "pipeline/parse/parse_base.h" | ||||
| #include "pipeline/parse/data_converter.h" | #include "pipeline/parse/data_converter.h" | ||||
| #include "pipeline/resource.h" | #include "pipeline/resource.h" | ||||
| #include "pipeline/validator.h" | #include "pipeline/validator.h" | ||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| #include "optimizer/cse.h" | #include "optimizer/cse.h" | ||||
| #include "optimizer/graph_kernel_reuse.h" | |||||
| #include "optimizer/clean.h" | #include "optimizer/clean.h" | ||||
| #include "optimizer/irpass.h" | #include "optimizer/irpass.h" | ||||
| #include "optimizer/control_depend.h" | #include "optimizer/control_depend.h" | ||||
| @@ -38,6 +40,7 @@ | |||||
| #include "parallel/step_auto_parallel.h" | #include "parallel/step_auto_parallel.h" | ||||
| #include "parallel/allreduce_fusion/step_allreduce_fusion.h" | #include "parallel/allreduce_fusion/step_allreduce_fusion.h" | ||||
| #include "utils/any.h" | #include "utils/any.h" | ||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace pipeline { | namespace pipeline { | ||||
| @@ -162,6 +165,40 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| return map; | return map; | ||||
| } | } | ||||
| OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| opt::OptPassConfig interface_fusion = opt::OptPassConfig({ | |||||
| irpass.mark_interface_fusion_, | |||||
| }); | |||||
| OptPassGroupMap map({ | |||||
| {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())}, | |||||
| {"interface_fusion", interface_fusion}, | |||||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | |||||
| {"cse", opt::OptPassConfig(opt::CSE(false))}, | |||||
| }); | |||||
| return map; | |||||
| } | |||||
| OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| opt::OptPassConfig elim_1 = opt::OptPassConfig({ | |||||
| irpass.addn_eliminate_, | |||||
| irpass.incorporate_getitem_from_param_, | |||||
| }); | |||||
| opt::OptPassConfig elim_2 = opt::OptPassConfig({ | |||||
| irpass.unused_parameter_eliminate_, | |||||
| irpass.unused_output_eliminate_, | |||||
| }); | |||||
| OptPassGroupMap map({ | |||||
| {"elim_1", elim_1}, | |||||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | |||||
| {"elim_2", elim_2}, | |||||
| }); | |||||
| return map; | |||||
| } | |||||
| OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}}); | |||||
| } | |||||
| OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { | OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { | ||||
| opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true); | opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true); | ||||
| OptPassGroupMap map({ | OptPassGroupMap map({ | ||||
| @@ -191,8 +228,19 @@ void InitOpt(const ResourcePtr &res) { | |||||
| opt::irpass::OptimizeIRPassLib irpass; | opt::irpass::OptimizeIRPassLib irpass; | ||||
| g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); | g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); | ||||
| g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); | g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); | ||||
| g_pass_opts["opt_graph_kernel_a"] = | |||||
| Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); | |||||
| g_pass_opts["opt_graph_kernel_b"] = | |||||
| Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); | |||||
| g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); | |||||
| g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); | g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); | ||||
| g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); | g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); | ||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (!(context_ptr->enable_graph_kernel())) { | |||||
| g_pass_opts["opt_graph_kernel_a"]->set_enable(false); | |||||
| g_pass_opts["opt_graph_kernel_b"]->set_enable(false); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -224,9 +272,13 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) { | |||||
| bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } | bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } | ||||
| bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } | bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } | ||||
| bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } | |||||
| bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } | |||||
| bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } | bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } | ||||
| bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } | bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } | ||||
| bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } | |||||
| bool AddControlDependPass(const ResourcePtr &res) { | bool AddControlDependPass(const ResourcePtr &res) { | ||||
| FuncGraphPtr func_graph = res->func_graph(); | FuncGraphPtr func_graph = res->func_graph(); | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| @@ -270,8 +322,10 @@ bool InferenceOptPreparePass(const ResourcePtr &res) { | |||||
| std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | ||||
| {"opt_a", OptPassAGroup}, | {"opt_a", OptPassAGroup}, | ||||
| {"opt_b", OptPassBGroup}, | {"opt_b", OptPassBGroup}, | ||||
| {"add_control_depend", AddControlDependPass}, | |||||
| {"cconv", CconvPass}}; | |||||
| {"cconv", CconvPass}, | |||||
| {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, | |||||
| {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, | |||||
| {"add_control_depend", AddControlDependPass}}; | |||||
| std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | ||||
| {"opt_a", OptPassAGroup}, | {"opt_a", OptPassAGroup}, | ||||
| @@ -488,7 +488,7 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const | |||||
| #ifdef ENABLE_INFER | #ifdef ENABLE_INFER | ||||
| // Now don't use the graph because the exec ge function doesn't take effect | // Now don't use the graph because the exec ge function doesn't take effect | ||||
| MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph); | MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph); | ||||
| if (ENABLE_TRAIN != info.at(phase)->func_graph->flags()["training"]) { | |||||
| if (ENABLE_TRAIN != info.at(phase)->func_graph->has_flag("training")) { | |||||
| MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries"; | MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries"; | ||||
| ConfigManager::GetInstance().ResetConfig(); | ConfigManager::GetInstance().ResetConfig(); | ||||
| return py::none(); | return py::none(); | ||||
| @@ -165,7 +165,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa | |||||
| MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list); | MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list); | ||||
| // If there is a loop variant, all arguments need to be broadened to avoid wrong constant propagation. | // If there is a loop variant, all arguments need to be broadened to avoid wrong constant propagation. | ||||
| if (!(joined_args_spec_list == args_spec_list)) { | if (!(joined_args_spec_list == args_spec_list)) { | ||||
| func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||||
| func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||||
| } | } | ||||
| return joined_args_spec_list; | return joined_args_spec_list; | ||||
| } | } | ||||
| @@ -178,7 +178,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa | |||||
| // If there is a loop variant, all arguments need to be broadened to avoid wrong constant propagation. | // If there is a loop variant, all arguments need to be broadened to avoid wrong constant propagation. | ||||
| if (!(joined_args_spec_list == args_spec_list)) { | if (!(joined_args_spec_list == args_spec_list)) { | ||||
| trace_.push_back(joined_args_spec_list); | trace_.push_back(joined_args_spec_list); | ||||
| func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||||
| func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); | MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); | ||||
| return joined_args_spec_list; | return joined_args_spec_list; | ||||
| @@ -479,7 +479,7 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { | |||||
| if (undetermined_fgs) { | if (undetermined_fgs) { | ||||
| auto fg_parent = fg->parent(); | auto fg_parent = fg->parent(); | ||||
| MS_EXCEPTION_IF_NULL(fg_parent); | MS_EXCEPTION_IF_NULL(fg_parent); | ||||
| fg_parent->set_flags(kFuncGraphFlagUndetermined, true); | |||||
| fg_parent->set_flag(kFuncGraphFlagUndetermined, true); | |||||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "pre_activate/ascend/ascend_backend_optimization.h" | #include "pre_activate/ascend/ascend_backend_optimization.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <set> | |||||
| #include "pre_activate/common/optimizer.h" | #include "pre_activate/common/optimizer.h" | ||||
| #include "pre_activate/ascend/ir_fission/bn_split.h" | #include "pre_activate/ascend/ir_fission/bn_split.h" | ||||
| #include "pre_activate/ascend/ir_fission/bn_grad_split.h" | #include "pre_activate/ascend/ir_fission/bn_grad_split.h" | ||||
| @@ -63,6 +64,9 @@ | |||||
| #include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" | #include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" | ||||
| #include "pre_activate/pass/eliminate_redundant_op.h" | #include "pre_activate/pass/eliminate_redundant_op.h" | ||||
| #include "pre_activate/pass/common_subexpression_elimination.h" | #include "pre_activate/pass/common_subexpression_elimination.h" | ||||
| #include "pre_activate/pass/fuse_graph_kernel.h" | |||||
| #include "pre_activate/pass/fuse_basic.h" | |||||
| #include "pre_activate/pass/add_atomic_clean.h" | |||||
| #include "pre_activate/ascend/format_type/merge_cast_to_op.h" | #include "pre_activate/ascend/format_type/merge_cast_to_op.h" | ||||
| #include "pre_activate/ascend/format_type/check_consistency.h" | #include "pre_activate/ascend/format_type/check_consistency.h" | ||||
| #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" | #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" | ||||
| @@ -88,6 +92,8 @@ | |||||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" | #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" | ||||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" | #include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" | ||||
| #include "pre_activate/ascend/ir_fission/split_fission.h" | #include "pre_activate/ascend/ir_fission/split_fission.h" | ||||
| #include "pre_activate/ascend/format_type/modify_ops_attrs.h" | |||||
| #include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" | |||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| @@ -164,6 +170,19 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g | |||||
| kernel_graph->SetExecOrderByDefault(); | kernel_graph->SetExecOrderByDefault(); | ||||
| } | } | ||||
| void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||||
| MS_EXCEPTION_IF_NULL(optimizer); | |||||
| auto common_process = std::make_shared<PassManager>("graph_kernel_common_process"); | |||||
| MS_EXCEPTION_IF_NULL(common_process); | |||||
| common_process->AddPass(std::make_shared<ModifyOpAttrs>()); | |||||
| common_process->AddPass(std::make_shared<RemoveNoUseReshapeOp>()); | |||||
| optimizer->AddPassManager(common_process); | |||||
| (void)optimizer->Optimize(kernel_graph); | |||||
| kernel_graph->SetExecOrderByDefault(); | |||||
| } | |||||
| void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| auto optimizer = std::make_shared<GraphOptimizer>(); | auto optimizer = std::make_shared<GraphOptimizer>(); | ||||
| @@ -332,7 +351,94 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||||
| std::string file_path = | std::string file_path = | ||||
| save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | ||||
| DumpIR(file_path, kernel_graph, true); | DumpIR(file_path, kernel_graph, true); | ||||
| DumpIRProto(kernel_graph, "after_hwopt_" + std::to_string(kernel_graph->graph_id())); | |||||
| DumpIRProto(kernel_graph, "after_hwopt"); | |||||
| kernel_graph->DumpFuncGraph("hwopt_d_end"); | |||||
| } | |||||
| } | |||||
| void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||||
| bool is_before_kernel_select) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (!(context_ptr->enable_graph_kernel())) { | |||||
| return; | |||||
| } | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" + | |||||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||||
| ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| // Fuse graph kernels with basic ops | |||||
| FuseGraphKernel(kernel_graph, is_before_kernel_select); | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" + | |||||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||||
| ".ir"; | |||||
| DumpIR(file_path, kernel_graph, true); | |||||
| } | |||||
| } | |||||
| void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||||
| bool is_before_kernel_select) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (!(context_ptr->enable_graph_kernel())) { | |||||
| return; | |||||
| } | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" + | |||||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||||
| ".ir"; | |||||
| DumpIR(file_path, kernel_graph, true); | |||||
| } | |||||
| // Fuse basic ops with basic ops | |||||
| FuseBasic(kernel_graph, is_before_kernel_select); | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" + | |||||
| std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + | |||||
| ".ir"; | |||||
| DumpIR(file_path, kernel_graph, true); | |||||
| } | |||||
| } | |||||
| void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| if (!(context_ptr->enable_graph_kernel())) { | |||||
| return; | |||||
| } | |||||
| bool save_graphs = context_ptr->save_graphs_flag(); | |||||
| auto save_graphs_path = context_ptr->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| if (save_graphs) { | |||||
| std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" + | |||||
| std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph); | |||||
| } | |||||
| AddAtomicClean(kernel_graph); | |||||
| if (save_graphs) { | |||||
| std::string file_path = | |||||
| save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; | |||||
| DumpIR(file_path, kernel_graph, true); | |||||
| } | } | ||||
| } | } | ||||
| @@ -24,6 +24,12 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne | |||||
| void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||||
| void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||||
| bool is_before_kernel_select = false); | |||||
| void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph, | |||||
| bool is_before_kernel_select = false); | |||||
| void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||||
| void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "device/kernel_info.h" | #include "device/kernel_info.h" | ||||
| #include "kernel/oplib/oplib.h" | #include "kernel/oplib/oplib.h" | ||||
| #include "kernel/common_utils.h" | |||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| @@ -229,7 +230,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr | |||||
| if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) { | if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) { | ||||
| builder.SetKernelType(KernelType::TBE_KERNEL); | builder.SetKernelType(KernelType::TBE_KERNEL); | ||||
| } else { | } else { | ||||
| builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); | |||||
| builder.SetKernelType(KernelType::AKG_KERNEL); | |||||
| } | } | ||||
| // if kernel info is null, it means this function is running in a unit test | // if kernel info is null, it means this function is running in a unit test | ||||
| if (cast->kernel_info() == nullptr) { | if (cast->kernel_info() == nullptr) { | ||||
| @@ -284,22 +285,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; | std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; | ||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | ||||
| TypeId origin_type; | |||||
| const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||||
| TypeId origin_type(kTypeUnknown); | |||||
| auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); | auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); | ||||
| auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); | auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); | ||||
| auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { | |||||
| if (node->isa<ValueNode>()) { | |||||
| return true; | |||||
| } | |||||
| if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| }; | |||||
| auto real_input_node = kernel_with_index.first; | auto real_input_node = kernel_with_index.first; | ||||
| if (is_weight_boundary(real_input_node)) { | |||||
| if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| // weight | // weight | ||||
| origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); | |||||
| origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index); | |||||
| if (origin_type == kTypeUnknown) { | |||||
| origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); | |||||
| } | |||||
| } else { | } else { | ||||
| // feature map | // feature map | ||||
| origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | ||||
| @@ -307,9 +303,13 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod | |||||
| const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); | const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); | ||||
| const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index); | const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index); | ||||
| const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); | const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); | ||||
| if (origin_type != device_type) { | |||||
| // In graph kernel, the input may be a tensor ValueNode used as a parameter; the eliminate |||||
| // pass will not remove the redundant cast in that case, so the unused cast is not inserted here. |||||
| if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) { | |||||
| new_inputs.push_back(cur_input); | |||||
| } else if (origin_type != device_type) { | |||||
| auto cast = | auto cast = | ||||
| AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, origin_type); | |||||
| AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type); | |||||
| MS_EXCEPTION_IF_NULL(cast); | MS_EXCEPTION_IF_NULL(cast); | ||||
| cast->set_scope(cnode->scope()); | cast->set_scope(cnode->scope()); | ||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); | AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); | ||||
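A self-contained sketch of the origin-type selection added above, with a plain enum for TypeId and hypothetical lambdas in place of the AnfAlgo queries: a weight input (or any input when the function graph carries the graph-kernel attribute) takes the recorded output precision and falls back to the device data type when that precision is unknown, while a feature-map input keeps the inferred data type; a cast is then inserted only when the resolved origin type differs from the device type.

    #include <functional>

    // Stand-ins for mindspore type ids; the concrete values are irrelevant to the sketch.
    enum TypeId { kTypeUnknown = 0, kNumberTypeFloat16, kNumberTypeFloat32 };

    struct InputTypeQueries {                      // hypothetical bundle of the AnfAlgo lookups
      std::function<TypeId()> recorded_precision;  // GetPrevNodeOutputPrecision
      std::function<TypeId()> device_type;         // GetPrevNodeOutputDeviceDataType
      std::function<TypeId()> infer_type;          // GetPrevNodeOutputInferDataType
    };

    TypeId ResolveInputOriginType(bool weight_or_graph_kernel, const InputTypeQueries &q) {
      if (weight_or_graph_kernel) {
        TypeId origin = q.recorded_precision();
        return origin == kTypeUnknown ? q.device_type() : origin;
      }
      return q.infer_type();  // feature map: keep the inferred data type
    }

    bool NeedInputCast(TypeId origin_type, TypeId device_type) { return origin_type != device_type; }

    int main() {
      InputTypeQueries q{[] { return kTypeUnknown; },         // no fixed precision recorded
                         [] { return kNumberTypeFloat16; },   // device dtype
                         [] { return kNumberTypeFloat32; }};  // inferred dtype
      // A weight with no recorded precision falls back to the fp16 device type, so no cast is needed.
      return NeedInputCast(ResolveInputOriginType(true, q), kNumberTypeFloat16) ? 1 : 0;
    }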
| @@ -17,9 +17,12 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "common/utils.h" | |||||
| #include "kernel/common_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -74,11 +77,21 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt | |||||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { | if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { | |||||
| if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) { | |||||
| MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(node) << "[" | |||||
| << node->DebugString() << "]"; | |||||
| std::vector<AnfNodePtr> todos = {node}; | |||||
| if (AnfAlgo::IsGraphKernel(node)) { | |||||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||||
| MS_EXCEPTION_IF_NULL(sub_graph); | |||||
| kernel::GetValidKernelNodes(sub_graph, &todos); | |||||
| } | |||||
| for (auto &t : todos) { | |||||
| CNodePtr cnode = t->cast<CNodePtr>(); | |||||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { | |||||
| if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) { | |||||
| MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "[" | |||||
| << cnode->DebugString() << "]"; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
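A minimal sketch of the traversal change in CheckConsistency, with a plain struct standing in for AnfNode/CNode and the sub-graph collected through a hypothetical member instead of kernel::GetValidKernelNodes: when the node is a graph-kernel composite, every real kernel node of its sub-graph is checked as well, not just the composite itself.

    #include <stdexcept>
    #include <string>
    #include <vector>

    // Stand-in node: a graph-kernel composite owns the kernel nodes of its sub-graph.
    struct Node {
      std::string name;
      bool is_graph_kernel;
      bool consistent;                    // result of the format/dtype input checks, precomputed here
      std::vector<Node *> inner_kernels;  // hypothetical: valid kernel nodes of the sub-graph
    };

    void CheckConsistency(Node *node) {
      std::vector<Node *> todos = {node};
      if (node->is_graph_kernel) {
        // Expand the worklist with the sub-graph's real kernel nodes.
        todos.insert(todos.end(), node->inner_kernels.begin(), node->inner_kernels.end());
      }
      for (Node *t : todos) {
        if (!t->consistent) {
          throw std::runtime_error("Found inconsistent format or data type! Op: " + t->name);
        }
      }
    }

    int main() {
      Node inner{"Add", false, true, {}};
      Node composite{"GraphKernel", true, true, {&inner}};
      CheckConsistency(&composite);  // passes: the composite and its inner Add are both consistent
      return 0;
    }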
| @@ -18,6 +18,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <utility> | |||||
| #include "device/kernel_info.h" | #include "device/kernel_info.h" | ||||
| #include "pre_activate/ascend/ascend_helper.h" | #include "pre_activate/ascend/ascend_helper.h" | ||||
| @@ -27,34 +28,45 @@ | |||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "kernel/common_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||||
| AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::vector<bool> &need_insert_cast) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| std::vector<AnfNodePtr> make_tuple_inputs; | std::vector<AnfNodePtr> make_tuple_inputs; | ||||
| AbstractBasePtrList abstract_list; | AbstractBasePtrList abstract_list; | ||||
| make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | ||||
| for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) { | for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) { | ||||
| const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); | |||||
| const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); | |||||
| const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); | |||||
| const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); | |||||
| AnfNodePtr replace_node = nullptr; | |||||
| const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); | |||||
| const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); | |||||
| auto idx = NewValueNode(SizeToInt(output_idx)); | auto idx = NewValueNode(SizeToInt(output_idx)); | ||||
| MS_EXCEPTION_IF_NULL(idx); | MS_EXCEPTION_IF_NULL(idx); | ||||
| auto imm = std::make_shared<Int32Imm>(output_idx); | auto imm = std::make_shared<Int32Imm>(output_idx); | ||||
| idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm)); | idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm)); | ||||
| auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); | auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); | ||||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, getitem.get()); | |||||
| AnfNodePtr replace_node = nullptr; | |||||
| if (origin_type != device_type) { | |||||
| replace_node = | |||||
| AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, origin_type); | |||||
| MS_EXCEPTION_IF_NULL(replace_node); | |||||
| replace_node->set_scope(cnode->scope()); | |||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); | |||||
| AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get()); | |||||
| if (need_insert_cast[output_idx]) { | |||||
| const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); | |||||
| TypeId origin_type(kTypeUnknown); | |||||
| if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); | |||||
| } | |||||
| origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; | |||||
| const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); | |||||
| if (origin_type != device_type) { | |||||
| replace_node = | |||||
| AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, infer_type); | |||||
| MS_EXCEPTION_IF_NULL(replace_node); | |||||
| replace_node->set_scope(cnode->scope()); | |||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); | |||||
| } else { | |||||
| replace_node = getitem; | |||||
| } | |||||
| } else { | } else { | ||||
| replace_node = getitem; | replace_node = getitem; | ||||
| } | } | ||||
| @@ -65,9 +77,10 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo | |||||
| MS_EXCEPTION_IF_NULL(make_tuple); | MS_EXCEPTION_IF_NULL(make_tuple); | ||||
| make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | ||||
| return make_tuple; | return make_tuple; | ||||
| } | |||||
| } // namespace | |||||
| AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||||
| AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::vector<bool> &need_insert_cast) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { | if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { | ||||
| @@ -76,14 +89,23 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c | |||||
| MS_EXCEPTION_IF_NULL(cnode->Type()); | MS_EXCEPTION_IF_NULL(cnode->Type()); | ||||
| // Single output | // Single output | ||||
| if (!cnode->Type()->isa<Tuple>()) { | if (!cnode->Type()->isa<Tuple>()) { | ||||
| if (!need_insert_cast[0]) { | |||||
| return cnode; | |||||
| } | |||||
| const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); | const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); | ||||
| std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0); | std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0); | ||||
| const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, 0); | |||||
| const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0); | |||||
| TypeId origin_type(kTypeUnknown); | |||||
| if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); | |||||
| } | |||||
| origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; | |||||
| const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); | const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); | ||||
| AnfNodePtr replace_node = cnode; | AnfNodePtr replace_node = cnode; | ||||
| if (origin_type != device_type) { | if (origin_type != device_type) { | ||||
| replace_node = | replace_node = | ||||
| AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, origin_type); | |||||
| AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, infer_type); | |||||
| MS_EXCEPTION_IF_NULL(replace_node); | MS_EXCEPTION_IF_NULL(replace_node); | ||||
| replace_node->set_scope(cnode->scope()); | replace_node->set_scope(cnode->scope()); | ||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); | AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); | ||||
| @@ -91,7 +113,57 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c | |||||
| return replace_node; | return replace_node; | ||||
| } | } | ||||
| // Multiple output | // Multiple output | ||||
| return InsertCastForMultipleOutput(func_graph, cnode); | |||||
| return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast); | |||||
| } | |||||
| AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||||
| // insert cast for ops in graph kernel. | |||||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||||
| MS_EXCEPTION_IF_NULL(sub_graph); | |||||
| auto mng = sub_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(mng); | |||||
| std::vector<AnfNodePtr> todo; | |||||
| std::vector<std::pair<AnfNodePtr, size_t>> graph_rets; | |||||
| kernel::GetValidKernelNodes(sub_graph, &todo); | |||||
| kernel::GetGraphRealOutput(sub_graph, &graph_rets); | |||||
| for (auto &t : todo) { | |||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t); | |||||
| // process input | |||||
| CNodePtr t_cnode = t->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(t_cnode); | |||||
| auto t_new_node = InsertCastForInput(sub_graph, t_cnode); | |||||
| AnfNodePtr t_new_node_1 = nullptr; | |||||
| std::vector<bool> need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true); | |||||
| // process output | |||||
| auto iter = std::find_if(graph_rets.begin(), graph_rets.end(), | |||||
| [&t](const std::pair<AnfNodePtr, size_t> &ret) { return ret.first == t; }); | |||||
| if (iter != graph_rets.end()) { | |||||
| auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t); | |||||
| auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second); | |||||
| auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin()); | |||||
| if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) { | |||||
| need_insert_cast[iter->second] = false; | |||||
| } else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) { | |||||
| need_insert_cast[iter->second] = false; | |||||
| } | |||||
| t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); | |||||
| } else { | |||||
| t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); | |||||
| } | |||||
| if (t_new_node_1 != nullptr && t_new_node_1 != t) { | |||||
| (void)mng->Replace(t, t_new_node_1); | |||||
| } | |||||
| } | |||||
| // insert cast for graph kernel. | |||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||||
| // process input | |||||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto new_node = InsertCastForInput(func_graph, cnode); | |||||
| // process output | |||||
| return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true)); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
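A compact sketch of the per-output decision inside ProcessGraphKernelOp above, with a plain enum standing in for the AnfAlgo type queries: the output cast inside the sub-graph is skipped only when the inner node's output device type already matches the composite graph-kernel output's device type and either no output precision is recorded or the recorded precision agrees with that device type; every other output keeps need_insert_cast set to true.

    // Stand-in type ids for the sketch; values are arbitrary.
    enum TypeId { kTypeUnknown = 0, kNumberTypeFloat16, kNumberTypeFloat32 };

    // Returns the value stored into need_insert_cast[output_index]:
    // true means InsertCastForOutput must still consider a cast for this output.
    bool NeedInsertCast(TypeId fixed_precision,      // GetCNodeOutputPrecision, may be kTypeUnknown
                        TypeId node_device_type,     // device type of the inner node's output
                        TypeId graph_device_type) {  // device type of the composite node's output
      if (fixed_precision == kTypeUnknown && node_device_type == graph_device_type) {
        return false;
      }
      if (fixed_precision == node_device_type && node_device_type == graph_device_type) {
        return false;
      }
      return true;
    }

    int main() {
      // fp16 on both sides of the sub-graph boundary and no fixed precision: no cast is needed.
      return NeedInsertCast(kTypeUnknown, kNumberTypeFloat16, kNumberTypeFloat16) ? 1 : 0;
    }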
| @@ -106,13 +178,27 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo | |||||
| if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { | if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (AnfAlgo::IsGraphKernel(node)) { | |||||
| return ProcessGraphKernelOp(func_graph, node); | |||||
| } else { | |||||
| // insert cast for single op. | |||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||||
| // process input | |||||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto new_node = InsertCastForInput(func_graph, cnode); | |||||
| // process output | |||||
| return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true)); | |||||
| } | |||||
| // insert cast for single op. | |||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | ||||
| // process input | // process input | ||||
| CNodePtr cnode = node->cast<CNodePtr>(); | CNodePtr cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto new_node = InsertCastForInput(func_graph, cnode); | auto new_node = InsertCastForInput(func_graph, cnode); | ||||
| // process output | // process output | ||||
| return InsertCastForOutput(func_graph, new_node); | |||||
| return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true)); | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -133,6 +133,9 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto next_cnode = next_node->cast<CNodePtr>(); | auto next_cnode = next_node->cast<CNodePtr>(); | ||||
| if (AnfAlgo::IsGraphKernel(next_node)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto next_op_name = AnfAlgo::GetCNodeName(next_node); | auto next_op_name = AnfAlgo::GetCNodeName(next_node); | ||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | ||||
| kernel_query->Query(next_cnode, &kernel_info_list); | kernel_query->Query(next_cnode, &kernel_info_list); | ||||
| @@ -206,6 +209,9 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(prior_op); | MS_EXCEPTION_IF_NULL(prior_op); | ||||
| if (AnfAlgo::IsGraphKernel(prior_op)) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | ||||
| kernel_query->Query(prior_op, &kernel_info_list); | kernel_query->Query(prior_op, &kernel_info_list); | ||||
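Both merge directions gain the same early-return guard; a trivial sketch of that pattern, with a stand-in flag for AnfAlgo::IsGraphKernel: a Cast is never folded into a fused graph-kernel neighbour, so the pass simply declines to rewrite.

    struct Node { bool is_graph_kernel; };  // stand-in for a CNode

    // Returns nullptr (no rewrite) when the neighbour is a graph-kernel composite,
    // mirroring the guards added to MergeCastToNextOp and MergeCastToPriorOp.
    Node *TryMergeCast(Node *cast, Node *neighbour) {
      if (neighbour == nullptr || neighbour->is_graph_kernel) {
        return nullptr;
      }
      return cast;  // placeholder for the real kernel-info query and merge
    }

    int main() {
      Node cast{false};
      Node fused{true};
      return TryMergeCast(&cast, &fused) == nullptr ? 0 : 1;
    }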
| @@ -0,0 +1,99 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/format_type/modify_ops_attrs.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "utils/utils.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| #include "kernel/common_utils.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "operator/ops.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) { | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | |||||
| auto input_format = AnfAlgo::GetInputFormat(cnode, 0); | |||||
| if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) { | |||||
| return nullptr; | |||||
| } | |||||
| AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode); | |||||
| return cnode; | |||||
| } | |||||
| AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) { | |||||
| auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0); | |||||
| if (input_shape.size() != 5) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto multiples = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrMultiples); | |||||
| if (multiples.size() == 4 && multiples[1] == 1) { | |||||
| multiples.push_back(1); | |||||
| AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode); | |||||
| } | |||||
| return cnode; | |||||
| } | |||||
| AnfNodePtr ModifyAttrs(const CNodePtr &cnode) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||||
| if (op_name == prim::kPrimTile->name()) { | |||||
| return ModifyTileOpAttrs(cnode); | |||||
| } else if (op_name == prim::kPrimReduceSum->name()) { | |||||
| // kPrimReduceMean | |||||
| // kPrimReduceSum | |||||
| // kPrimReduceAll | |||||
| // kPrimReduceMax | |||||
| // kPrimReduceMin | |||||
| return ModifyReduceOpsAttrs(cnode); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace | |||||
| const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node); | |||||
| auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| auto manager = fg->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::vector<AnfNodePtr> todos; | |||||
| kernel::GetValidKernelNodes(fg, &todos); | |||||
| for (auto &t : todos) { | |||||
| auto new_node = ModifyAttrs(t->cast<CNodePtr>()); | |||||
| if (new_node != nullptr && new_node != t) { | |||||
| (void)manager->Replace(t, new_node); | |||||
| } | |||||
| } | |||||
| return node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
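A worked example of the Tile adjustment in the new pass, using a plain std::vector in place of the node attribute: a 4-element multiples attribute whose channel multiple is 1 is padded with a trailing 1 so that it lines up with the 5-D NC1HWC0 device shape; reduce ops handled by the same pass simply get keep_dims forced to true when their NC1HWC0 input is not already 5-D.

    #include <cassert>
    #include <vector>

    // Mirrors ModifyTileOpAttrs: pad a 4-D Tile "multiples" attribute to 5-D for
    // NC1HWC0 inputs, but only when the channel multiple is 1.
    std::vector<int> PadTileMultiples(std::vector<int> multiples) {
      if (multiples.size() == 4 && multiples[1] == 1) {
        multiples.push_back(1);  // extra multiple for the trailing C0 axis
      }
      return multiples;
    }

    int main() {
      assert((PadTileMultiples({2, 1, 3, 4}) == std::vector<int>{2, 1, 3, 4, 1}));
      assert((PadTileMultiples({2, 3, 3, 4}) == std::vector<int>{2, 3, 3, 4}));  // channel multiple != 1: unchanged
      return 0;
    }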
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class ModifyOpAttrs : public PatternProcessPass { | |||||
| public: | |||||
| explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {} | |||||
| ~ModifyOpAttrs() override = default; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H | |||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "pre_activate/common/helper.h" | |||||
| #include "kernel/common_utils.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "operator/ops.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||||
| if (op_name != prim::kPrimReshape->name()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | |||||
| auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0); | |||||
| if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) { | |||||
| return nullptr; | |||||
| } | |||||
| return cnode->input(1); | |||||
| } | |||||
| } // namespace | |||||
| const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node); | |||||
| auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| auto manager = fg->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::vector<AnfNodePtr> todos; | |||||
| kernel::GetValidKernelNodes(fg, &todos); | |||||
| for (auto &t : todos) { | |||||
| auto new_node = RemoveReshapeOp(t->cast<CNodePtr>()); | |||||
| if (new_node != nullptr && new_node != t) { | |||||
| (void)manager->Replace(t, new_node); | |||||
| } | |||||
| } | |||||
| return node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
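A small sketch of the condition this pass applies inside each graph-kernel sub-graph, with a plain struct standing in for the AnfAlgo shape and format queries: a Reshape whose input is a rank-1 tensor already laid out as NC1HWC0 is treated as a no-op, and the node is replaced by that input.

    #include <string>
    #include <vector>

    struct PrevOutputInfo {             // stand-in for the previous node's output shape/format
      std::vector<size_t> infer_shape;
      std::string format;
    };

    // Mirrors RemoveReshapeOp: true means the Reshape can be dropped and its
    // first input forwarded in its place.
    bool ReshapeIsNoUse(const PrevOutputInfo &input) {
      return input.infer_shape.size() == 1 && input.format == "NC1HWC0";
    }

    int main() {
      PrevOutputInfo input{{16}, "NC1HWC0"};
      return ReshapeIsNoUse(input) ? 0 : 1;  // rank-1 NC1HWC0 input: the Reshape is removed
    }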
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class RemoveNoUseReshapeOp : public PatternProcessPass { | |||||
| public: | |||||
| explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {} | |||||
| ~RemoveNoUseReshapeOp() override = default; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H | |||||
| @@ -121,6 +121,9 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f | |||||
| if (node == nullptr || !node->isa<CNode>()) { | if (node == nullptr || !node->isa<CNode>()) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (AnfAlgo::IsGraphKernel(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| std::vector<CNodePtr> cast_nodes; | std::vector<CNodePtr> cast_nodes; | ||||
| @@ -102,9 +102,12 @@ bool UnVisited(const BaseRef &n) { | |||||
| auto prim_py = value->cast<PrimitivePtr>(); | auto prim_py = value->cast<PrimitivePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(prim_py); | MS_EXCEPTION_IF_NULL(prim_py); | ||||
| return !prim_py->HasAttr(kAttrVisited); | return !prim_py->HasAttr(kAttrVisited); | ||||
| } else { | |||||
| return false; | |||||
| } else if (IsValueNode<FuncGraph>(in)) { | |||||
| auto func_graph = GetValueNode<FuncGraphPtr>(in); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| return !func_graph->has_flag(kAttrVisited); | |||||
| } | } | ||||
| return false; | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -188,9 +191,12 @@ bool Visited(const BaseRef &n) { | |||||
| auto prim_py = value->cast<PrimitivePtr>(); | auto prim_py = value->cast<PrimitivePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(prim_py); | MS_EXCEPTION_IF_NULL(prim_py); | ||||
| return prim_py->HasAttr(kAttrVisited); | return prim_py->HasAttr(kAttrVisited); | ||||
| } else { | |||||
| return false; | |||||
| } else if (IsValueNode<FuncGraph>(in)) { | |||||
| auto func_graph = GetValueNode<FuncGraphPtr>(in); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| return func_graph->has_flag(kAttrVisited); | |||||
| } | } | ||||
| return false; | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
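The UnVisited and Visited predicates above now also recognise a graph-kernel composite, whose callee is a FuncGraph value node carrying the visited flag instead of a primitive attribute; a minimal sketch with a tagged stand-in value type (not the real BaseRef/ValueNode machinery):

    #include <string>
    #include <unordered_set>

    // Stand-in for the two kinds of value node the predicates inspect.
    struct ValueNodeLike {
      bool is_primitive;
      bool is_func_graph;
      std::unordered_set<std::string> tags;  // attrs on a Primitive / flags on a FuncGraph
    };

    bool Visited(const ValueNodeLike &v) {
      if (v.is_primitive || v.is_func_graph) {
        return v.tags.count("visited") > 0;  // kAttrVisited
      }
      return false;
    }

    bool UnVisited(const ValueNodeLike &v) {
      if (v.is_primitive || v.is_func_graph) {
        return v.tags.count("visited") == 0;
      }
      return false;
    }

    int main() {
      ValueNodeLike graph_kernel{false, true, {}};
      return (UnVisited(graph_kernel) && !Visited(graph_kernel)) ? 0 : 1;
    }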