Browse Source

!208 parser mem_type for variable

Merge pull request !208 from 陈叶朦/development
pull/208/MERGE
i-robot Gitee 5 years ago
parent
commit
49e21ec4f8
4 changed files with 21 additions and 0 deletions
  1. +5
    -0
      parser/common/op_def/variable_op.cc
  2. +2
    -0
      parser/common/op_def/variable_op.h
  3. +2
    -0
      parser/common/parser_types.cc
  4. +12
    -0
      parser/tensorflow/tensorflow_variable_v2_parser.cc

+ 5
- 0
parser/common/op_def/variable_op.cc View File

@@ -43,6 +43,11 @@ VariableOperator &VariableOperator::Placement(const std::string &placement) {
return *this;
}

VariableOperator &VariableOperator::MemType(const uint32_t &mem_type) {
Attr(ATTR_OUTPUT_MEMORY_TYPE, mem_type);
return *this;
}

VariableOperator &VariableOperator::SrcType(const int64_t &dtype) {
Attr(VAR_ATTR_DTYPE, dtype);
return *this;


+ 2
- 0
parser/common/op_def/variable_op.h View File

@@ -35,6 +35,8 @@ class VariableOperator : public ParserOperator {

VariableOperator &Placement(const std::string &placement);

VariableOperator &MemType(const uint32_t &mem_type);

VariableOperator &SrcType(const int64_t &dtype);

VariableOperator &VarShape(const std::vector<int64_t> &shape_value);


+ 2
- 0
parser/common/parser_types.cc View File

@@ -347,7 +347,9 @@ const char *HCOMREDUCESCATTER = "HcomReduceScatter";
const char *HCOMSEND = "HcomSend";
const char *HCOMRECEIVE = "HcomReceive";
const char *HCOMREMOTEREAD = "HcomRemoteRead";
const char *HCOMREMOTEREFREAD = "HcomRemoteRefRead";
const char *HCOMREMOTEWRITE = "HcomRemoteWrite";
const char *HCOMREMOTESCATTERWRITE = "HcomRemoteScatterWrite";

const char *VARASSIGN = "VarAssign";
const char *VARISINITIALIZEDOP = "VarIsInitializedOp";


+ 12
- 0
parser/tensorflow/tensorflow_variable_v2_parser.cc View File

@@ -225,6 +225,17 @@ static void ParsePlacement(const domi::tensorflow::NodeDef *node, VariableOperat
}
}

static void ParseMemType(const domi::tensorflow::NodeDef *node, VariableOperator *op) {
// The upper caller guarantees input params is not empty.
string node_src_name = node->name();
domi::tensorflow::AttrValue attr_value;
GELOGI("Start to parse mem_type, %s", node_src_name.c_str());
if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_OUTPUT_MEMORY_TYPE, attr_value)) {
uint32_t mem_type = attr_value.i();
op->MemType(mem_type);
}
}

Status ParseParams(const Message *op_src, VariableOperator *op) {
GE_CHECK_NOTNULL(op_src);
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src);
@@ -241,6 +252,7 @@ Status ParseParams(const Message *op_src, VariableOperator *op) {
GE_RETURN_IF_ERROR(ParseSrcType(node, op));
GE_RETURN_IF_ERROR(ParseVarShape(node, op));
ParsePlacement(node, op);
ParseMemType(node, op);

GELOGD("VariabeV2 OP parser params success.op name : %s.", node->name().c_str());



Loading…
Cancel
Save