|
|
|
@@ -25,9 +25,30 @@ |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
namespace mindspore::kernel { |
|
|
|
using mindspore::kernel::tbe::TbeUtils; |
|
|
|
|
|
|
|
static size_t GenFusionJsonHash(const nlohmann::json &fusion_json) { |
|
|
|
// get an copy |
|
|
|
nlohmann::json fusion_json_copy = fusion_json; |
|
|
|
auto &op_lists = fusion_json_copy["op_list"]; |
|
|
|
for (auto &op : op_lists) { |
|
|
|
op.erase("name"); |
|
|
|
for (auto &output_desc : op["output_desc"]) { |
|
|
|
output_desc.erase("name"); |
|
|
|
} |
|
|
|
if (op["type"] != "Data") { |
|
|
|
for (auto &input_desc : op["input_desc"]) { |
|
|
|
input_desc.erase("name"); |
|
|
|
} |
|
|
|
for (auto &list_arg : op["prebuild_output_attrs"]["list_args"]) { |
|
|
|
list_arg.erase("name"); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return std::hash<std::string>()(fusion_json_copy.dump()); |
|
|
|
} |
|
|
|
|
|
|
|
std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> &fusion_scopes) { |
|
|
|
MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size(); |
|
|
|
std::map<int64_t, KernelModPtr> kernel_mod_ret; |
|
|
|
@@ -41,8 +62,8 @@ std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> |
|
|
|
continue; |
|
|
|
} |
|
|
|
// gen kernel_name & check cache |
|
|
|
std::string json_str = fusion_op.dump(); |
|
|
|
size_t hash_id = std::hash<std::string>()(json_str); |
|
|
|
size_t hash_id = GenFusionJsonHash(fusion_op); |
|
|
|
MS_LOG(INFO) << "Fusion op hash id: " << hash_id; |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); |
|
|
|
@@ -102,5 +123,4 @@ std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> |
|
|
|
MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num; |
|
|
|
return kernel_mod_ret; |
|
|
|
} |
|
|
|
} // namespace kernel |
|
|
|
} // namespace mindspore |
|
|
|
} // namespace mindspore::kernel |