Browse Source

!395 Fix oplib coddex

Merge pull request !395 from zjun/fix_oplib_codexx
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
3c317c71f4
4 changed files with 21 additions and 25 deletions
  1. +2
    -11
      mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc
  2. +15
    -11
      mindspore/ccsrc/kernel/oplib/oplib.cc
  3. +1
    -0
      mindspore/ccsrc/kernel/oplib/oplib.h
  4. +3
    -3
      mindspore/ops/op_info_register.py

+ 2
- 11
mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc View File

@@ -39,8 +39,6 @@ namespace mindspore {
namespace kernel {
using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>;

const std::vector<std::string> local_framework_op_vec = {kInitData, kGetNext, kDropoutGenMask, kPrint};

bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num,
std::vector<size_t> *input_size_list) {
MS_EXCEPTION_IF_NULL(anf_node);
@@ -298,19 +296,12 @@ KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node) {
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
kernel_mod_ptr->SetAnfNode(anf_node);
kernel_mod_ptr->SetNodeName(op_name);
auto iter = std::find(local_framework_op_vec.begin(), local_framework_op_vec.end(), op_name);
if (iter != local_framework_op_vec.end()) {
if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) {
MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!";
}
} else {
MS_LOG(EXCEPTION) << "Aicpu don't support node [" << op_name << "]";
if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) {
MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!";
}

if (!SetIOSize(anf_node, kernel_mod_ptr)) {
MS_LOG(EXCEPTION) << "Set input output size list failed.";
}

return kernel_mod_ptr;
}
} // namespace kernel


+ 15
- 11
mindspore/ccsrc/kernel/oplib/oplib.cc View File

@@ -94,6 +94,20 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path)
return ret;
}

void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info) {
op_info->set_async_flag(obj.at(kAsyncFlag));
op_info->set_binfile_name(obj.at(kBinfileName));
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kOpPattern) != obj.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
}
if (obj.find(kDynamicFormat) != obj.end()) {
op_info->set_dynamic_format(obj.at(kDynamicFormat));
}
}

bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type,
const std::string& impl_path) {
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
@@ -103,17 +117,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
op_info->set_imply_type(imply_type);
op_info->set_fusion_type(obj.at(kFusionType));
if (imply_type == kTBE) {
op_info->set_async_flag(obj.at(kAsyncFlag));
op_info->set_binfile_name(obj.at(kBinfileName));
op_info->set_compute_cost(obj.at(kComputeCost));
op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kOpPattern) != obj.end()) {
op_info->set_op_pattern(obj.at(kOpPattern));
}
if (obj.find(kDynamicFormat) != obj.end()) {
op_info->set_dynamic_format(obj.at(kDynamicFormat));
}
DecodeTBESpecificInfo(obj, op_info);
}
auto attrs = obj.at(kAttr);
for (const auto& attr : attrs) {


+ 1
- 0
mindspore/ccsrc/kernel/oplib/oplib.h View File

@@ -40,6 +40,7 @@ class OpLib {
const std::shared_ptr<OpInfo>& op_info);
static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
size_t index);
static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info);
static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format);
static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info);


+ 3
- 3
mindspore/ops/op_info_register.py View File

@@ -57,7 +57,7 @@ def op_info_register(op_info):
return register_decorator


class RegOp():
class RegOp:
"""
Base class for op info register.

@@ -483,9 +483,9 @@ class TBERegOp(RegOp):
return self


class DataType():
class DataType:
"""
Various combinations of dtype and formatself.
Various combinations of dtype and format.

The current list below maybe not completed. If necessary, please add it.
"""


Loading…
Cancel
Save