GitOrigin-RevId: d295abb5da
tags/v1.3.0
| @@ -305,6 +305,7 @@ def dump_graph( | |||||
| output_vars: Union[Dict[str, VarNode], List[VarNode]], | output_vars: Union[Dict[str, VarNode], List[VarNode]], | ||||
| *, | *, | ||||
| keep_var_name: int = 1, | keep_var_name: int = 1, | ||||
| keep_op_name: bool = True, | |||||
| keep_param_name: bool = False, | keep_param_name: bool = False, | ||||
| keep_opr_priority: bool = False, | keep_opr_priority: bool = False, | ||||
| strip_info_file=None, | strip_info_file=None, | ||||
| @@ -325,6 +326,7 @@ def dump_graph( | |||||
| * 0: none of the names are kept | * 0: none of the names are kept | ||||
| * 1: (default)keep names of output vars | * 1: (default)keep names of output vars | ||||
| * 2: keep names of all (output and internal) vars | * 2: keep names of all (output and internal) vars | ||||
| :param keep_op_name: whether to keep operator names. | |||||
| :param keep_param_name: whether to keep param names, so param values can be | :param keep_param_name: whether to keep param names, so param values can be | ||||
| easily manipulated after loading model | easily manipulated after loading model | ||||
| :param keep_opr_priority: whether to keep priority setting for operators | :param keep_opr_priority: whether to keep priority setting for operators | ||||
| @@ -368,6 +370,7 @@ def dump_graph( | |||||
| dump_content = _imperative_rt.dump_graph( | dump_content = _imperative_rt.dump_graph( | ||||
| ov, | ov, | ||||
| keep_var_name, | keep_var_name, | ||||
| keep_op_name, | |||||
| keep_param_name, | keep_param_name, | ||||
| keep_opr_priority, | keep_opr_priority, | ||||
| stat, | stat, | ||||
| @@ -294,6 +294,7 @@ void init_graph_rt(py::module m) { | |||||
| m.def("dump_graph", []( | m.def("dump_graph", []( | ||||
| const std::vector<VarNode*>& dest_vars, | const std::vector<VarNode*>& dest_vars, | ||||
| int keep_var_name, | int keep_var_name, | ||||
| bool keep_op_name, | |||||
| bool keep_param_name, | bool keep_param_name, | ||||
| bool keep_opr_priority, | bool keep_opr_priority, | ||||
| py::list& stat, | py::list& stat, | ||||
| @@ -306,7 +307,7 @@ void init_graph_rt(py::module m) { | |||||
| SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | ||||
| ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name, | ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name, | ||||
| keep_opr_priority}; | |||||
| keep_opr_priority, keep_op_name}; | |||||
| auto rst = dumper->dump(symvars, config); | auto rst = dumper->dump(symvars, config); | ||||
| for (auto i : rst.inputs) { | for (auto i : rst.inputs) { | ||||
| @@ -124,6 +124,7 @@ table Operator { | |||||
| blobs:[Blob]; | blobs:[Blob]; | ||||
| /// Operator may want to save more than one OperatorParam | /// Operator may want to save more than one OperatorParam | ||||
| additional_params:[OperatorParam]; | additional_params:[OperatorParam]; | ||||
| name:string; | |||||
| } | } | ||||
| struct OutputVar { | struct OutputVar { | ||||
| @@ -208,6 +208,11 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( | |||||
| inputs = m_builder.CreateVector(v); | inputs = m_builder.CreateVector(v); | ||||
| } | } | ||||
| Offset<String> operator_name; | |||||
| if (m_config.keep_op_name) { | |||||
| operator_name = m_builder.CreateSharedString(opr->name()); | |||||
| } | |||||
| Offset<Vector<Offset<String>>> output_names; | Offset<Vector<Offset<String>>> output_names; | ||||
| if (m_config.keep_var_name >= 2 || | if (m_config.keep_var_name >= 2 || | ||||
| (m_config.keep_var_name == 1 && | (m_config.keep_var_name == 1 && | ||||
| @@ -255,6 +260,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( | |||||
| } | } | ||||
| builder.add_comp_node(comp_node); | builder.add_comp_node(comp_node); | ||||
| builder.add_output_name(output_names); | builder.add_output_name(output_names); | ||||
| builder.add_name(operator_name); | |||||
| builder.add_output_dtype(output_dtype); | builder.add_output_dtype(output_dtype); | ||||
| if (param_cnt > 0) { | if (param_cnt > 0) { | ||||
| builder.add_param_type(m_cur_opr_param_type[0]); | builder.add_param_type(m_cur_opr_param_type[0]); | ||||
| @@ -698,6 +704,9 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( | |||||
| if (fbopr->output_dtype()) { | if (fbopr->output_dtype()) { | ||||
| config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype())); | config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype())); | ||||
| } | } | ||||
| if (fbopr->name()) { | |||||
| config.name(fbopr->name()->str()); | |||||
| } | |||||
| if (fbopr->comp_node()) { | if (fbopr->comp_node()) { | ||||
| auto cnt = fbopr->comp_node()->size(); | auto cnt = fbopr->comp_node()->size(); | ||||
| cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt); | cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt); | ||||
| @@ -43,6 +43,9 @@ struct GraphDumpConfig { | |||||
| //! whether to keep operator priorities | //! whether to keep operator priorities | ||||
| bool keep_opr_priority; | bool keep_opr_priority; | ||||
| //! whether to keep operator names | |||||
| bool keep_op_name; | |||||
| //! extra user data to be passed by dump caller into opr dump | //! extra user data to be passed by dump caller into opr dump | ||||
| //! implementations; useful for implementing nested opr dump | //! implementations; useful for implementing nested opr dump | ||||
| std::shared_ptr<UserDataContainer> user_data; | std::shared_ptr<UserDataContainer> user_data; | ||||
| @@ -57,12 +60,14 @@ struct GraphDumpConfig { | |||||
| GraphDumpConfig(int keep_var_name_ = 1, bool keep_param_name_ = false, | GraphDumpConfig(int keep_var_name_ = 1, bool keep_param_name_ = false, | ||||
| bool keep_opr_priority_ = false, | bool keep_opr_priority_ = false, | ||||
| bool keep_op_name_ = true, | |||||
| const std::shared_ptr<UserDataContainer>& user_data_ = | const std::shared_ptr<UserDataContainer>& user_data_ = | ||||
| std::make_shared<UserDataContainer>(), | std::make_shared<UserDataContainer>(), | ||||
| const TensorValueDumper& tensor_value_dumper_ = {}) | const TensorValueDumper& tensor_value_dumper_ = {}) | ||||
| : keep_var_name{keep_var_name_}, | : keep_var_name{keep_var_name_}, | ||||
| keep_param_name{keep_param_name_}, | keep_param_name{keep_param_name_}, | ||||
| keep_opr_priority{keep_opr_priority_}, | keep_opr_priority{keep_opr_priority_}, | ||||
| keep_op_name{keep_op_name_}, | |||||
| user_data{user_data_}, | user_data{user_data_}, | ||||
| tensor_value_dumper{tensor_value_dumper_} {} | tensor_value_dumper{tensor_value_dumper_} {} | ||||
| }; | }; | ||||
| @@ -711,6 +711,39 @@ TEST(TestSerializer2, ParamerizedDType) { | |||||
| load(); | load(); | ||||
| } | } | ||||
| TEST(TestSerializer2, OperatorName) { | |||||
| auto fname = GET_OUTPUT_FILE(); | |||||
| TensorShape shape{2, 3}; | |||||
| auto dump = [&]() { | |||||
| auto cn = CompNode::load("xpu0"); | |||||
| auto host_x = std::make_shared<HostTensorND>(cn, shape), | |||||
| host_y = std::make_shared<HostTensorND>(cn, shape); | |||||
| auto graph = ComputingGraph::make(); | |||||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}), | |||||
| y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"}); | |||||
| using Mode = opr::Elemwise::Mode; | |||||
| auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"}); | |||||
| auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()), | |||||
| GraphDumpFormat::FLATBUFFERS); | |||||
| auto rst = dumper->dump({z.rename("z")}); | |||||
| }; | |||||
| auto load = [&]() { | |||||
| HostTensorGenerator<> gen; | |||||
| auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()), | |||||
| GraphDumpFormat::FLATBUFFERS); | |||||
| auto rst = loader->load(); | |||||
| auto z = rst.output_var_map.at("z"); | |||||
| auto op_name = z.node()->owner_opr()->cname(); | |||||
| int cmp = strcmp(op_name, "add(x, y)"); | |||||
| EXPECT_EQ(cmp, 0); | |||||
| }; | |||||
| dump(); | |||||
| load(); | |||||
| } | |||||
| TEST(TestSerializer2, HasOutputDtype) { | TEST(TestSerializer2, HasOutputDtype) { | ||||
| auto fname = GET_OUTPUT_FILE(); | auto fname = GET_OUTPUT_FILE(); | ||||