Browse Source

!643 Fix code check

Merge pull request !643 from TangQunzhang/ge_dev
pull/644/head
TangQunzhang wqtshg 3 years ago
parent
commit
bb94ff254d
5 changed files with 29 additions and 84 deletions
  1. +29
    -28
      parser/tensorflow/tensorflow_parser.cc
  2. +0
    -16
      parser/tensorflow/tensorflow_ref_switch_parser.cc
  3. +0
    -20
      parser/tensorflow/tensorflow_ref_switch_parser.h
  4. +0
    -17
      parser/tensorflow/tensorflow_shape_n_parser.cc
  5. +0
    -3
      parser/tensorflow/tensorflow_shape_n_parser.h

+ 29
- 28
parser/tensorflow/tensorflow_parser.cc View File

@@ -915,8 +915,9 @@ Status TensorFlowModelParser::CheckOpType(const domi::tensorflow::NodeDef *node_
GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape");
GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP;
GELOGI("Set op %s to frameworkop", node_name.c_str());
framework_ops_[node_name] = node_def;);
);
framework_ops_[node_name] = node_def;
);
);

GE_IF_BOOL_EXEC(
op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN,
@@ -1777,8 +1778,8 @@ bool TensorFlowModelParser::MaybeFusionOp(shared_ptr<ge::ScopeGraph> &scope_grap
std::vector<ge::ScopeFusionOpInfo> info_list;
auto &impl = scope_graph->impl_;
if (impl->IsFusionOpChild(node_def->name(), info_list)) {
GE_IF_BOOL_EXEC(
info_list.size() > 0, for (size_t i = 0; i < info_list.size(); ++i) {
GE_IF_BOOL_EXEC(info_list.size() > 0,
for (size_t i = 0; i < info_list.size(); ++i) {
fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type);
fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].description);
fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(node_def);
@@ -3480,35 +3481,35 @@ void TensorFlowModelParser::RemoveInputAttr(domi::tensorflow::NodeDef *node_def,
attr_map->find(ge::ATTR_NAME_INPUT_TENSOR_DESC);
if (it == attr_map->end()) {
GELOGW("Failed to find input desc from tf node_def[%s]", node_def->name().c_str());
} else {
domi::tensorflow::AttrValue *input_attr_value = &(it->second);
auto tmp_attr = input_attr_value->mutable_list()->mutable_func();
auto attr_it = tmp_attr->begin();
int index = 0;
for (auto input_it = inputs->begin(); input_it != inputs->end(); ++input_it, ++index) {
// 1.decide whether to remove the input
bool flag = false;
for (auto &remove_input : remove_inputs_map) {
string remove_input_name = remove_input.first;
vector<int> remove_input_indexs = remove_input.second;
if ((*input_it) == remove_input_name &&
std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) {
GELOGD("Remove input attr:%s, index:%d", remove_input_name.c_str(), index);
flag = true;
break;
}
return;
}
domi::tensorflow::AttrValue *input_attr_value = &(it->second);
auto tmp_attr = input_attr_value->mutable_list()->mutable_func();
auto attr_it = tmp_attr->begin();
int index = 0;
for (auto input_it = inputs->begin(); input_it != inputs->end(); ++input_it, ++index) {
// 1.decide whether to remove the input
bool flag = false;
for (auto &remove_input : remove_inputs_map) {
string remove_input_name = remove_input.first;
vector<int> remove_input_indexs = remove_input.second;
if ((*input_it) == remove_input_name &&
std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) {
GELOGD("Remove input attr:%s, index:%d", remove_input_name.c_str(), index);
flag = true;
break;
}
}

if (flag) {
// 2.1 remove the input attr
if (!tmp_attr->empty() && (attr_it != tmp_attr->end())) {
attr_it = tmp_attr->erase(attr_it);
} else {
++attr_it;
}
if (flag) {
// 2.1 remove the input attr
if (!tmp_attr->empty() && (attr_it != tmp_attr->end())) {
attr_it = tmp_attr->erase(attr_it);
} else {
++attr_it;
}
} else {
++attr_it;
}
}
}


+ 0
- 16
parser/tensorflow/tensorflow_ref_switch_parser.cc View File

@@ -64,29 +64,13 @@ Status TensorFlowRefSwitchParser::ParseParams(const Message *op_src, ge::OpDescP
op.Name(node->name());

GELOGI("RefSwitch Op %s ParseParams Begin.", node->name().c_str());
GE_RETURN_IF_ERROR(PreParseParams(node, &op));

GE_RETURN_WITH_LOG_IF_ERROR(ParseT(node, &op), "Parse T for node %s failed.", node->name().c_str());

GE_RETURN_IF_ERROR(PostParseParams(node, &op));

Status status = ConvertToOpDesc(op, op_dest);

return status;
}

// AUTO GEN PLEASE DO NOT MODIFY IT
Status TensorFlowRefSwitchParser::PreParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op) {
(void)node;
(void)op;
return SUCCESS;
}

Status TensorFlowRefSwitchParser::PostParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op) {
(void)node;
(void)op;
return SUCCESS;
}

REGISTER_OP_PARSER_CREATOR(TENSORFLOW, REFSWITCH, TensorFlowRefSwitchParser);
} // namespace ge

+ 0
- 20
parser/tensorflow/tensorflow_ref_switch_parser.h View File

@@ -35,26 +35,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowRefSwitchParser : public TensorFlowOpPars
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;

protected:
/**
* @ingroup domi_omg
* @brief 解析模型文件信息
* @param [in] v_input_const 待解析的模型数据
* @param [out] node 解析后的模型数据
* @return SUCCESS 解析成功
* @return FAILED 解析失败
*/
Status PreParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op);

/**
* @ingroup domi_omg
* @brief 解析模型文件信息
* @param [in] v_input_const 待解析的模型数据
* @param [out] node 解析后的模型数据
* @return SUCCESS 解析成功
* @return FAILED 解析失败
*/
Status PostParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op);

/**
* @ingroup domi_omg
* @brief 解析模型文件信息


+ 0
- 17
parser/tensorflow/tensorflow_shape_n_parser.cc View File

@@ -100,16 +100,12 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr
ShapeNOperator op;
op.Name(node->name());

GE_RETURN_IF_ERROR(PreParseParams(node, &op));

GE_RETURN_WITH_LOG_IF_ERROR(ParseInType(node, &op), "Parse in type for node %s failed.", node->name().c_str());

GE_RETURN_WITH_LOG_IF_ERROR(ParseN(node, &op), "Parse N for node %s failed.", node->name().c_str());

GE_RETURN_WITH_LOG_IF_ERROR(ParseOutType(node, &op), "Parse out type for node %s failed.", node->name().c_str());

GE_RETURN_IF_ERROR(PostParseParams(node, &op));

// add dynamic input/output
domi::tensorflow::AttrValue attr_num;
CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr_num),
@@ -154,18 +150,5 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr
return SUCCESS;
}

// AUTO GEN PLEASE DO NOT MODIFY IT
Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) {
(void)node;
(void)op;
return SUCCESS;
}

Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op) {
(void)node;
(void)op;
return SUCCESS;
}

REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SHAPEN, TensorFlowShapeNParser);
} // namespace ge

+ 0
- 3
parser/tensorflow/tensorflow_shape_n_parser.h View File

@@ -27,9 +27,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;

protected:
Status PreParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op);
Status PostParseParams(const domi::tensorflow::NodeDef *node, const ShapeNOperator *op);

static Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op);
static Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op);
static Status ParseOutType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op);


Loading…
Cancel
Save