|
|
|
@@ -17,9 +17,14 @@ |
|
|
|
#include "tools/converter/adapter/acl/mapper/arithmetic_mapper.h" |
|
|
|
#include <memory> |
|
|
|
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" |
|
|
|
#include "src/common/log_util.h" |
|
|
|
#include "ops/real_div.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
static const std::map<std::string, PrimitivePtr> kDivTypeMap = {{"Div", std::make_shared<ops::Div>()}, |
|
|
|
{"RealDiv", std::make_shared<ops::RealDiv>()}}; |
|
|
|
|
|
|
|
STATUS AddFusionMapper::Mapper(const CNodePtr &cnode) { |
|
|
|
auto dst_prim = std::make_shared<ops::Add>(); |
|
|
|
if (MoveAttrMap(cnode, dst_prim) != RET_OK) { |
|
|
|
@@ -30,11 +35,25 @@ STATUS AddFusionMapper::Mapper(const CNodePtr &cnode) { |
|
|
|
} |
|
|
|
|
|
|
|
STATUS DivFusionMapper::Mapper(const CNodePtr &cnode) { |
|
|
|
auto dst_prim = std::make_shared<ops::Div>(); |
|
|
|
if (MoveAttrMap(cnode, dst_prim) != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "DivFusion mapper failed."; |
|
|
|
return RET_ERROR; |
|
|
|
ValueNodePtr value_node = nullptr; |
|
|
|
PrimitivePtr src_prim = nullptr; |
|
|
|
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Get primitive from cnode failed."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
std::string original_name = "Div"; |
|
|
|
auto name_ptr = src_prim->GetAttr(ops::kOriginalOpName); |
|
|
|
if (name_ptr != nullptr) { |
|
|
|
original_name = GetValue<std::string>(name_ptr); |
|
|
|
original_name = original_name.empty() ? "Div" : original_name; |
|
|
|
} |
|
|
|
PrimitivePtr dst_prim = nullptr; |
|
|
|
if (kDivTypeMap.find(original_name) != kDivTypeMap.end()) { |
|
|
|
dst_prim = kDivTypeMap.at(original_name); |
|
|
|
} |
|
|
|
CHECK_NULL_RETURN(dst_prim); |
|
|
|
dst_prim->SetAttrs(src_prim->attrs()); |
|
|
|
value_node->set_value(dst_prim); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
|