|
|
|
@@ -19,6 +19,7 @@ |
|
|
|
#include <unordered_map> |
|
|
|
#include <memory> |
|
|
|
#include <map> |
|
|
|
#include <fstream> |
|
|
|
#include "utils/log_adapter.h" |
|
|
|
#include "utils/overload.h" |
|
|
|
#include "utils/context/ms_context.h" |
|
|
|
@@ -59,7 +60,7 @@ constexpr auto kNeedCompile = "need_compile"; |
|
|
|
constexpr auto kShape = "shape"; |
|
|
|
std::vector<std::shared_ptr<OpInfo>> OpLib::op_info_; |
|
|
|
|
|
|
|
std::string ImplTypeToStr(OpImplyType impl_type) { |
|
|
|
static std::string ImplTypeToStr(OpImplyType impl_type) { |
|
|
|
switch (impl_type) { |
|
|
|
case kTBE: |
|
|
|
return kTbe; |
|
|
|
@@ -124,6 +125,50 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool OpLib::RegOpFromLocalInfo() { |
|
|
|
MS_LOG(INFO) << "Start"; |
|
|
|
static bool has_load = false; |
|
|
|
if (has_load) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
has_load = true; |
|
|
|
std::string dir = common::GetEnv("MINDSPORE_OP_INFO_PATH"); |
|
|
|
if (dir.empty()) { |
|
|
|
MS_LOG(INFO) << "MindSpore op info path does not been setted. use op info from python pass."; |
|
|
|
return true; |
|
|
|
} |
|
|
|
char real_path[PATH_MAX] = {0}; |
|
|
|
if (dir.size() >= PATH_MAX) { |
|
|
|
MS_LOG(ERROR) << "Op info path is invalid: " << dir; |
|
|
|
return false; |
|
|
|
} |
|
|
|
#if defined(_WIN32) || defined(_WIN64) |
|
|
|
if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Op info path is invalid: " << dir; |
|
|
|
return false; |
|
|
|
} |
|
|
|
#else |
|
|
|
if (realpath(common::SafeCStr(dir), real_path) == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Op info path is invalid: " << dir; |
|
|
|
return false; |
|
|
|
} |
|
|
|
#endif |
|
|
|
MS_LOG(INFO) << "Start to read op info from local file."; |
|
|
|
std::ifstream file(real_path); |
|
|
|
if (!file.is_open()) { |
|
|
|
MS_LOG(ERROR) << "Find op info file failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
std::string line; |
|
|
|
while (getline(file, line)) { |
|
|
|
if (!line.empty()) { |
|
|
|
(void)OpLib::RegOp(line, ""); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "End"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, |
|
|
|
const std::string &impl_path) { |
|
|
|
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>(); |
|
|
|
@@ -160,14 +205,16 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
if (CheckRepetition(op_info)) { |
|
|
|
MS_LOG(WARNING) << "This op info has been already registed. op name: " << op_info->op_name() |
|
|
|
<< ", impl type: " << ImplTypeToStr(op_info->imply_type()) |
|
|
|
<< ", impl path: " << op_info->impl_path(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (!GetRefInfo(op_info)) { |
|
|
|
MS_LOG(ERROR) << "GetRefInfo Failed"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (!CheckRepetition(op_info)) { |
|
|
|
MS_LOG(ERROR) << "CheckRepetition Failed"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
op_info_.push_back(op_info); |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -269,6 +316,9 @@ bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { |
|
|
|
if (!OpLib::RegOpFromLocalInfo()) { |
|
|
|
MS_LOG(INFO) << "Warning reg local op info failed."; |
|
|
|
} |
|
|
|
auto context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context); |
|
|
|
bool is_gpu = (context->device_target() == kGPUDevice); |
|
|
|
@@ -283,8 +333,8 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im |
|
|
|
return op_info; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) |
|
|
|
<< ", current op num: " << op_info_.size(); |
|
|
|
MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) |
|
|
|
<< ", current op num: " << op_info_.size(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -313,17 +363,19 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) { |
|
|
|
} |
|
|
|
|
|
|
|
bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) { |
|
|
|
bool has_register = false; |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
for (const auto &exist_op_info : op_info_) { |
|
|
|
MS_EXCEPTION_IF_NULL(exist_op_info); |
|
|
|
if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && |
|
|
|
exist_op_info->impl_path() != op_info->impl_path()) { |
|
|
|
MS_LOG(ERROR) << "Op has already exist, please use other name, op name: " << op_info->op_name() |
|
|
|
<< " op type: " << ImplTypeToStr(op_info->imply_type()); |
|
|
|
return false; |
|
|
|
exist_op_info->impl_path() == op_info->impl_path()) { |
|
|
|
MS_LOG(INFO) << "Op has already exist, please use other name, op name: " << op_info->op_name() |
|
|
|
<< " op type: " << ImplTypeToStr(op_info->imply_type()); |
|
|
|
has_register = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
return has_register; |
|
|
|
} |
|
|
|
} // namespace kernel |
|
|
|
} // namespace mindspore |