Browse Source

st div op and parameter format for ocr precision

tags/v1.6.0
zhengyuanhua 4 years ago
parent
commit
742e6b61b4
5 changed files with 47 additions and 5 deletions
  1. +2
    -0
      mindspore/ccsrc/transform/graph_ir/convert.cc
  2. +19
    -0
      mindspore/core/utils/check_convert_utils.cc
  3. +1
    -0
      mindspore/core/utils/check_convert_utils.h
  4. +23
    -4
      mindspore/lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc
  5. +2
    -1
      mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc

+ 2
- 0
mindspore/ccsrc/transform/graph_ir/convert.cc View File

@@ -1877,6 +1877,8 @@ void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &attr.second);
if (converted) {
format = attr.second->ToString();
} else {
CheckAndConvertUtils::GetFormatStringVal(prim, &format);
}
}
if (format != "NCDHW" && format != "NHWC") {


+ 19
- 0
mindspore/core/utils/check_convert_utils.cc View File

@@ -271,6 +271,25 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type,
return true;
}

void CheckAndConvertUtils::GetFormatStringVal(const PrimitivePtr &prim, std::string *format) {
if (prim == nullptr || format == nullptr) {
MS_LOG(DEBUG) << "Prim or format is nullptr.";
return;
}
auto value_ptr = prim->GetAttr(ops::kFormat);
if (value_ptr == nullptr) {
MS_LOG(DEBUG) << "Val is nullptr! op type = " << prim->name();
return;
}
int64_t data_format;
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value_ptr, &data_format);
if (result) {
if (DataFormatToStrMap.find(data_format) != DataFormatToStrMap.end()) {
*format = DataFormatToStrMap.at(data_format);
}
}
}

void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
ValuePtr *const value) {
if (value == nullptr || *value == nullptr) {


+ 1
- 0
mindspore/core/utils/check_convert_utils.h View File

@@ -316,6 +316,7 @@ class CheckAndConvertUtils {
static void CheckInputArgs(const std::vector<AbstractBasePtr> &input_args, const CompareEnum compare_operator,
const int64_t match_value, const std::string &prim_name);
static bool HasDynamicShapeInput(const AbstractBasePtrList &abs_list);
static void GetFormatStringVal(const PrimitivePtr &prim, std::string *format);

private:
static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,


+ 23
- 4
mindspore/lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc View File

@@ -17,9 +17,14 @@
#include "tools/converter/adapter/acl/mapper/arithmetic_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "src/common/log_util.h"
#include "ops/real_div.h"

namespace mindspore {
namespace lite {
static const std::map<std::string, PrimitivePtr> kDivTypeMap = {{"Div", std::make_shared<ops::Div>()},
{"RealDiv", std::make_shared<ops::RealDiv>()}};

STATUS AddFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Add>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
@@ -30,11 +35,25 @@ STATUS AddFusionMapper::Mapper(const CNodePtr &cnode) {
}

STATUS DivFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Div>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "DivFusion mapper failed.";
return RET_ERROR;
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
std::string original_name = "Div";
auto name_ptr = src_prim->GetAttr(ops::kOriginalOpName);
if (name_ptr != nullptr) {
original_name = GetValue<std::string>(name_ptr);
original_name = original_name.empty() ? "Div" : original_name;
}
PrimitivePtr dst_prim = nullptr;
if (kDivTypeMap.find(original_name) != kDivTypeMap.end()) {
dst_prim = kDivTypeMap.at(original_name);
}
CHECK_NULL_RETURN(dst_prim);
dst_prim->SetAttrs(src_prim->attrs());
value_node->set_value(dst_prim);
return RET_OK;
}



+ 2
- 1
mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc View File

@@ -101,7 +101,8 @@ ops::PrimitiveC *TFDivParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "add op input failed";
return nullptr;
}

std::string original_name = tf_op.op();
prim->AddAttr(ops::kOriginalOpName, MakeValue(original_name));
return prim.release();
}



Loading…
Cancel
Save