| @@ -25,9 +25,30 @@ | |||||
| #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" | #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| namespace mindspore::kernel { | |||||
| using mindspore::kernel::tbe::TbeUtils; | using mindspore::kernel::tbe::TbeUtils; | ||||
| static size_t GenFusionJsonHash(const nlohmann::json &fusion_json) { | |||||
| // get an copy | |||||
| nlohmann::json fusion_json_copy = fusion_json; | |||||
| auto &op_lists = fusion_json_copy["op_list"]; | |||||
| for (auto &op : op_lists) { | |||||
| op.erase("name"); | |||||
| for (auto &output_desc : op["output_desc"]) { | |||||
| output_desc.erase("name"); | |||||
| } | |||||
| if (op["type"] != "Data") { | |||||
| for (auto &input_desc : op["input_desc"]) { | |||||
| input_desc.erase("name"); | |||||
| } | |||||
| for (auto &list_arg : op["prebuild_output_attrs"]["list_args"]) { | |||||
| list_arg.erase("name"); | |||||
| } | |||||
| } | |||||
| } | |||||
| return std::hash<std::string>()(fusion_json_copy.dump()); | |||||
| } | |||||
| std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> &fusion_scopes) { | std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> &fusion_scopes) { | ||||
| MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size(); | MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size(); | ||||
| std::map<int64_t, KernelModPtr> kernel_mod_ret; | std::map<int64_t, KernelModPtr> kernel_mod_ret; | ||||
| @@ -41,8 +62,8 @@ std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> | |||||
| continue; | continue; | ||||
| } | } | ||||
| // gen kernel_name & check cache | // gen kernel_name & check cache | ||||
| std::string json_str = fusion_op.dump(); | |||||
| size_t hash_id = std::hash<std::string>()(json_str); | |||||
| size_t hash_id = GenFusionJsonHash(fusion_op); | |||||
| MS_LOG(INFO) << "Fusion op hash id: " << hash_id; | |||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | ||||
| @@ -102,5 +123,4 @@ std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> | |||||
| MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num; | MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num; | ||||
| return kernel_mod_ret; | return kernel_mod_ret; | ||||
| } | } | ||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::kernel | |||||