diff --git a/parser/common/op_def/variable_op.cc b/parser/common/op_def/variable_op.cc index 2cf294e..6000a2e 100644 --- a/parser/common/op_def/variable_op.cc +++ b/parser/common/op_def/variable_op.cc @@ -43,6 +43,11 @@ VariableOperator &VariableOperator::Placement(const std::string &placement) { return *this; } +VariableOperator &VariableOperator::MemType(const uint32_t &mem_type) { + Attr(ATTR_OUTPUT_MEMORY_TYPE, mem_type); + return *this; +} + VariableOperator &VariableOperator::SrcType(const int64_t &dtype) { Attr(VAR_ATTR_DTYPE, dtype); return *this; diff --git a/parser/common/op_def/variable_op.h b/parser/common/op_def/variable_op.h index c9b85d3..166681e 100644 --- a/parser/common/op_def/variable_op.h +++ b/parser/common/op_def/variable_op.h @@ -35,6 +35,8 @@ class VariableOperator : public ParserOperator { VariableOperator &Placement(const std::string &placement); + VariableOperator &MemType(const uint32_t &mem_type); + VariableOperator &SrcType(const int64_t &dtype); VariableOperator &VarShape(const std::vector &shape_value); diff --git a/parser/common/parser_types.cc b/parser/common/parser_types.cc index 440e884..b53d37a 100644 --- a/parser/common/parser_types.cc +++ b/parser/common/parser_types.cc @@ -347,7 +347,9 @@ const char *HCOMREDUCESCATTER = "HcomReduceScatter"; const char *HCOMSEND = "HcomSend"; const char *HCOMRECEIVE = "HcomReceive"; const char *HCOMREMOTEREAD = "HcomRemoteRead"; +const char *HCOMREMOTEREFREAD = "HcomRemoteRefRead"; const char *HCOMREMOTEWRITE = "HcomRemoteWrite"; +const char *HCOMREMOTESCATTERWRITE = "HcomRemoteScatterWrite"; const char *VARASSIGN = "VarAssign"; const char *VARISINITIALIZEDOP = "VarIsInitializedOp"; diff --git a/parser/tensorflow/tensorflow_variable_v2_parser.cc b/parser/tensorflow/tensorflow_variable_v2_parser.cc index 139dd0e..c8b8a98 100644 --- a/parser/tensorflow/tensorflow_variable_v2_parser.cc +++ b/parser/tensorflow/tensorflow_variable_v2_parser.cc @@ -225,6 +225,17 @@ static void ParsePlacement(const domi::tensorflow::NodeDef *node, VariableOperat } } +static void ParseMemType(const domi::tensorflow::NodeDef *node, VariableOperator *op) { + // The upper caller guarantees input params is not empty. + string node_src_name = node->name(); + domi::tensorflow::AttrValue attr_value; + GELOGI("Start to parse mem_type, %s", node_src_name.c_str()); + if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_OUTPUT_MEMORY_TYPE, attr_value)) { + uint32_t mem_type = attr_value.i(); + op->MemType(mem_type); + } +} + Status ParseParams(const Message *op_src, VariableOperator *op) { GE_CHECK_NOTNULL(op_src); const NodeDef *node = reinterpret_cast(op_src); @@ -241,6 +252,7 @@ Status ParseParams(const Message *op_src, VariableOperator *op) { GE_RETURN_IF_ERROR(ParseSrcType(node, op)); GE_RETURN_IF_ERROR(ParseVarShape(node, op)); ParsePlacement(node, op); + ParseMemType(node, op); GELOGD("VariabeV2 OP parser params success.op name : %s.", node->name().c_str());