Browse Source

fusion op insert cache

tags/v1.1.0
jjfeing 5 years ago
parent
commit
a8366502e2
1 changed files with 26 additions and 6 deletions
  1. +26
    -6
      mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc

+ 26
- 6
mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc View File

@@ -25,9 +25,30 @@
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"


namespace mindspore {
namespace kernel {
namespace mindspore::kernel {
using mindspore::kernel::tbe::TbeUtils; using mindspore::kernel::tbe::TbeUtils;

static size_t GenFusionJsonHash(const nlohmann::json &fusion_json) {
// get an copy
nlohmann::json fusion_json_copy = fusion_json;
auto &op_lists = fusion_json_copy["op_list"];
for (auto &op : op_lists) {
op.erase("name");
for (auto &output_desc : op["output_desc"]) {
output_desc.erase("name");
}
if (op["type"] != "Data") {
for (auto &input_desc : op["input_desc"]) {
input_desc.erase("name");
}
for (auto &list_arg : op["prebuild_output_attrs"]["list_args"]) {
list_arg.erase("name");
}
}
}
return std::hash<std::string>()(fusion_json_copy.dump());
}

std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> &fusion_scopes) { std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> &fusion_scopes) {
MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size(); MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size();
std::map<int64_t, KernelModPtr> kernel_mod_ret; std::map<int64_t, KernelModPtr> kernel_mod_ret;
@@ -41,8 +62,8 @@ std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo>
continue; continue;
} }
// gen kernel_name & check cache // gen kernel_name & check cache
std::string json_str = fusion_op.dump();
size_t hash_id = std::hash<std::string>()(json_str);
size_t hash_id = GenFusionJsonHash(fusion_op);
MS_LOG(INFO) << "Fusion op hash id: " << hash_id;
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
@@ -102,5 +123,4 @@ std::map<int64_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo>
MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num; MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num;
return kernel_mod_ret; return kernel_mod_ret;
} }
} // namespace kernel
} // namespace mindspore
} // namespace mindspore::kernel

Loading…
Cancel
Save