Merge pull request !208 from 陈叶朦/developmentpull/208/MERGE
| @@ -43,6 +43,11 @@ VariableOperator &VariableOperator::Placement(const std::string &placement) { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| VariableOperator &VariableOperator::MemType(const uint32_t &mem_type) { | |||||
| Attr(ATTR_OUTPUT_MEMORY_TYPE, mem_type); | |||||
| return *this; | |||||
| } | |||||
| VariableOperator &VariableOperator::SrcType(const int64_t &dtype) { | VariableOperator &VariableOperator::SrcType(const int64_t &dtype) { | ||||
| Attr(VAR_ATTR_DTYPE, dtype); | Attr(VAR_ATTR_DTYPE, dtype); | ||||
| return *this; | return *this; | ||||
| @@ -35,6 +35,8 @@ class VariableOperator : public ParserOperator { | |||||
| VariableOperator &Placement(const std::string &placement); | VariableOperator &Placement(const std::string &placement); | ||||
| VariableOperator &MemType(const uint32_t &mem_type); | |||||
| VariableOperator &SrcType(const int64_t &dtype); | VariableOperator &SrcType(const int64_t &dtype); | ||||
| VariableOperator &VarShape(const std::vector<int64_t> &shape_value); | VariableOperator &VarShape(const std::vector<int64_t> &shape_value); | ||||
| @@ -347,7 +347,9 @@ const char *HCOMREDUCESCATTER = "HcomReduceScatter"; | |||||
| const char *HCOMSEND = "HcomSend"; | const char *HCOMSEND = "HcomSend"; | ||||
| const char *HCOMRECEIVE = "HcomReceive"; | const char *HCOMRECEIVE = "HcomReceive"; | ||||
| const char *HCOMREMOTEREAD = "HcomRemoteRead"; | const char *HCOMREMOTEREAD = "HcomRemoteRead"; | ||||
| const char *HCOMREMOTEREFREAD = "HcomRemoteRefRead"; | |||||
| const char *HCOMREMOTEWRITE = "HcomRemoteWrite"; | const char *HCOMREMOTEWRITE = "HcomRemoteWrite"; | ||||
| const char *HCOMREMOTESCATTERWRITE = "HcomRemoteScatterWrite"; | |||||
| const char *VARASSIGN = "VarAssign"; | const char *VARASSIGN = "VarAssign"; | ||||
| const char *VARISINITIALIZEDOP = "VarIsInitializedOp"; | const char *VARISINITIALIZEDOP = "VarIsInitializedOp"; | ||||
| @@ -225,6 +225,17 @@ static void ParsePlacement(const domi::tensorflow::NodeDef *node, VariableOperat | |||||
| } | } | ||||
| } | } | ||||
| static void ParseMemType(const domi::tensorflow::NodeDef *node, VariableOperator *op) { | |||||
| // The upper caller guarantees input params is not empty. | |||||
| string node_src_name = node->name(); | |||||
| domi::tensorflow::AttrValue attr_value; | |||||
| GELOGI("Start to parse mem_type, %s", node_src_name.c_str()); | |||||
| if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_OUTPUT_MEMORY_TYPE, attr_value)) { | |||||
| uint32_t mem_type = attr_value.i(); | |||||
| op->MemType(mem_type); | |||||
| } | |||||
| } | |||||
| Status ParseParams(const Message *op_src, VariableOperator *op) { | Status ParseParams(const Message *op_src, VariableOperator *op) { | ||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | ||||
| @@ -241,6 +252,7 @@ Status ParseParams(const Message *op_src, VariableOperator *op) { | |||||
| GE_RETURN_IF_ERROR(ParseSrcType(node, op)); | GE_RETURN_IF_ERROR(ParseSrcType(node, op)); | ||||
| GE_RETURN_IF_ERROR(ParseVarShape(node, op)); | GE_RETURN_IF_ERROR(ParseVarShape(node, op)); | ||||
| ParsePlacement(node, op); | ParsePlacement(node, op); | ||||
| ParseMemType(node, op); | |||||
| GELOGD("VariabeV2 OP parser params success.op name : %s.", node->name().c_str()); | GELOGD("VariabeV2 OP parser params success.op name : %s.", node->name().c_str()); | ||||