@@ -66,6 +66,10 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
  for (auto &node : manager->all_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    const AbstractBasePtr &prev_inferred = node->abstract();
    // Keep the previous inferred value for a CNode if it is loaded from MindIR.
    if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
      continue;
    }
    // Keep the previous inferred value for a ValueNode if the inferred value is not AbstractFunction.
    if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
      node->set_abstract(nullptr);
@@ -113,6 +117,69 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
  return ret;
}
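// Return the graph marked with the "is_load" attribute, or nullptr if no loaded graph exists.
// More than one loaded sub-graph is not supported yet.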
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  auto manager = res->manager();
  MS_EXCEPTION_IF_NULL(manager);
  FuncGraphPtr loaded_graph = nullptr;
  size_t loaded_graph_num = 0;
  auto all_graphs = manager->func_graphs();
  for (auto &graph : all_graphs) {
    MS_EXCEPTION_IF_NULL(graph);
    if (graph->has_attr("is_load")) {
      loaded_graph = graph;
      loaded_graph_num += 1;
    }
  }
  if (loaded_graph_num == 0) {
    return nullptr;
  }
  if (loaded_graph_num == 1) {
    return loaded_graph;
  }
  MS_LOG(EXCEPTION) << "The number of loaded sub-graphs should currently be less than 2, but got " << loaded_graph_num;
}
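// Check that every input of the root graph matches the corresponding input of the loaded graph in shape and type.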
void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
  MS_EXCEPTION_IF_NULL(res);
  auto manager = res->manager();
  MS_EXCEPTION_IF_NULL(manager);
  FuncGraphPtr root_graph = *(manager->roots().begin());
  auto root_inputs = root_graph->get_inputs();
  auto loaded_inputs = loaded_graph->get_inputs();
  size_t root_inputs_num = root_inputs.size();
  size_t loaded_inputs_num = loaded_inputs.size();
  if (root_inputs_num != loaded_inputs_num) {
    MS_LOG(EXCEPTION) << "The number of inputs " << root_inputs_num
                      << " is not equal to the number of inputs of the loaded graph " << loaded_inputs_num;
  }
  for (size_t index = 0; index < root_inputs_num; index++) {
    auto root_input = root_inputs[index];
    auto loaded_input = loaded_inputs[index];
    auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
    auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
    auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
    auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
    MS_EXCEPTION_IF_NULL(root_shape);
    MS_EXCEPTION_IF_NULL(loaded_shape);
    MS_EXCEPTION_IF_NULL(root_type);
    MS_EXCEPTION_IF_NULL(loaded_type);
    if (root_shape->shape() != loaded_shape->shape()) {
      MS_EXCEPTION(ValueError) << "The " << index
                               << "-th input shape differs from the loaded graph. Input shape: " << root_shape->ToString()
                               << ", input shape of the loaded graph: " << loaded_shape->ToString();
    }
    if (root_type->type_id() != loaded_type->type_id()) {
      MS_EXCEPTION(TypeError) << "The " << index
                              << "-th input type differs from the loaded graph. Input type: " << root_type->ToString()
                              << ", input type of the loaded graph: " << loaded_type->ToString();
    }
  }
}
bool ParseAction(const ResourcePtr &res) {
  if (!res->input()) {
    MS_LOG(EXCEPTION) << "Parse error";
@@ -255,12 +322,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "AbstractSpecialize error";
  }
  FuncGraphPtr func_graph = res->func_graph();
  abstract::AbstractBasePtrList args_spec = res->args_spec();
  auto context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  context->ParallelParameterContextInitShape(func_graph);
  // Get the original loaded graph to check its inputs later.
  auto loaded_graph_ptr = GetLoadedGraph(res);
  // Suppose that there is no KeywordArgument for the top graph.
  // Get the hyper parameters.
  for (const auto &param : func_graph->parameters()) {
@@ -294,7 +363,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
      }
    }
  }
  // Check the root graph inputs against the loaded graph after abstract specialization.
  if (loaded_graph_ptr != nullptr) {
    CheckRootInputShapeAndType(res, loaded_graph_ptr);
  }
  MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
  return true;
}
@@ -111,6 +111,7 @@ PYBIND11_MODULE(_c_expression, m) {
  (void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
  (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
  (py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load model as Graph.");
  (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
    .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
@@ -203,6 +203,19 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
  return true;
}
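// Clone the FuncGraph converted from a Python object (e.g. one loaded from MindIR) and mark the clone with the "is_load" attribute.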
bool ConvertFuncGraph(const py::object &obj, ValuePtr *const data) {
  MS_LOG(DEBUG) << "Converting FuncGraph object";
  auto func_graph = obj.cast<FuncGraphPtr>();
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "Resolve FuncGraph error: got a null pointer";
    return false;
  }
  auto new_fg = BasicClone(func_graph);
  new_fg->set_attr("is_load", MakeValue(true));
  *data = new_fg;
  return true;
}
bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
  MS_LOG(DEBUG) << "Converting slice object";
@@ -368,47 +381,21 @@ bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype
}
}  // namespace
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
  // check parameter valid
  if (data == nullptr) {
    MS_LOG(ERROR) << "Data is null pointer";
    return false;
  }
  bool ret = true;
bool ConvertSingleData(const py::object &obj, ValuePtr *const data) {
  MS_EXCEPTION_IF_NULL(data);
  ValuePtr converted = nullptr;
  if (py::isinstance<py::none>(obj)) {
    converted = kNone;
  } else if (py::isinstance<py::bool_>(obj)) {
    converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
  } else if (py::isinstance<py::int_>(obj)) {
    ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
  } else if (py::isinstance<py::float_>(obj)) {
    ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
  } else if (py::isinstance<py::str>(obj)) {
    converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
  } else if (py::isinstance<py::dict>(obj)) {
    ret = ConvertDict(obj, &converted, use_signature);
  } else if (py::isinstance<py::slice>(obj)) {
    ret = ConvertSlice(obj, &converted);
  } else if (py::isinstance<py::ellipsis>(obj)) {
    converted = kEllipsis;
  } else if (py::isinstance<py::tuple>(obj)) {
    ret = ConvertTuple(obj, &converted, use_signature);
  } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
    ret = ConvertCellList(obj, &converted, use_signature);
  } else if (py::isinstance<Cell>(obj)) {
    return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
  } else if (py::isinstance<py::list>(obj)) {
    ret = ConvertList(obj, &converted, use_signature);
  } else if (py::isinstance<py::module>(obj)) {
    ConvertNameSpace(obj, &converted);
  } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
    ConvertDataClass(obj, &converted);
  } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
    ret = ConvertPrimitive(obj, &converted, use_signature);
  } else if (py::isinstance<MetaFuncGraph>(obj)) {
    ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
  } else if (py::isinstance<Type>(obj)) {
    converted = obj.cast<TypePtr>();
  } else if (py::isinstance<Tensor>(obj)) {
@@ -425,9 +412,50 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
  } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
    converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
  } else {
    ret = ConvertOtherObj(obj, &converted);
    return false;
  }
  *data = converted;
  return true;
}
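// Convert a Python object to a Value. Objects not handled by ConvertSingleData fall through to the type-specific converters below.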
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
  // check parameter valid
  if (data == nullptr) {
    MS_LOG(ERROR) << "Data is null pointer";
    return false;
  }
  ValuePtr converted = nullptr;
  bool ret = ConvertSingleData(obj, &converted);
  if (ret) {
    *data = converted;
    return true;
  }
  if (py::isinstance<py::int_>(obj)) {
    ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
  } else if (py::isinstance<py::float_>(obj)) {
    ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
  } else if (py::isinstance<py::dict>(obj)) {
    ret = ConvertDict(obj, &converted, use_signature);
  } else if (py::isinstance<py::slice>(obj)) {
    ret = ConvertSlice(obj, &converted);
  } else if (py::isinstance<py::tuple>(obj)) {
    ret = ConvertTuple(obj, &converted, use_signature);
  } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
    ret = ConvertCellList(obj, &converted, use_signature);
  } else if (py::isinstance<Cell>(obj)) {
    return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
  } else if (py::isinstance<py::list>(obj)) {
    ret = ConvertList(obj, &converted, use_signature);
  } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
    ret = ConvertPrimitive(obj, &converted, use_signature);
  } else if (py::isinstance<MetaFuncGraph>(obj)) {
    ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
  } else if (py::isinstance<FuncGraph>(obj)) {
    ret = ConvertFuncGraph(obj, &converted);
  } else {
    ret = ConvertOtherObj(obj, &converted);
  }
  *data = converted;
  return ret;
}
@@ -113,6 +113,49 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
  return para_node;
}
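// Broaden the abstract of every CNode in the graph so previously inferred values are not kept as constants.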
void BroadenCNodeAbstract(const FuncGraphPtr &func_graph) {
  std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  for (const AnfNodePtr &node : nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto abstract = node->abstract();
    if (abstract != nullptr) {
      node->set_abstract(abstract->Broaden());
    }
  }
}
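// For a graph loaded from MindIR: move parameters that have default values (weights) into the top graph as
// hyper-parameters and keep the remaining parameters as the graph inputs.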
void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
  if (!value->isa<FuncGraph>()) {
    return;
  }
  auto resolved_graph = value->cast<FuncGraphPtr>();
  MS_EXCEPTION_IF_NULL(resolved_graph);
  if (!resolved_graph->has_attr("is_load")) {
    return;
  }
  auto top_graph = Parser::GetTopFuncGraph();
  std::vector<AnfNodePtr> input_params;
  for (auto const &param : resolved_graph->parameters()) {
    auto param_ptr = dyn_cast<Parameter>(param);
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (param_ptr->has_default()) {
      param_ptr->set_func_graph(top_graph);
      func_graph->add_used_global_parameters(param_ptr);
      // update top_graph
      top_graph->add_parameter(param_ptr);
      size_t hyper_param_count = top_graph->hyper_param_count();
      top_graph->set_hyper_param_count(hyper_param_count + 1);
    } else {
      input_params.push_back(param_ptr);
    }
  }
  resolved_graph->set_parameters(input_params);
  BroadenCNodeAbstract(resolved_graph);
}
bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
  AnfNodePtr output = nullptr;
  if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
@@ -146,6 +189,7 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
    return false;
  }
  MS_EXCEPTION_IF_NULL(convert_result);
  ConvertLoadedGraph(func_graph, convert_result);
  output = NewValueNode(convert_result);
  if (convert_result->isa<tensor::Tensor>()) {
    output = GetMixedPrecisionCastHelp(func_graph, output);
@@ -48,6 +48,7 @@
#include "pybind_api/pybind_patch.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "load_mindir/load_model.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/constants.h"
#include "ps/util.h"
@@ -1096,6 +1097,8 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s
#endif
}
FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::LoadMindIR(file_name); }
void ReleaseGeTsd() {
  auto context_ptr = MsContext::GetInstance();
  if (context_ptr != nullptr) {
@@ -140,6 +140,7 @@ void ClearResAtexit();
void ReleaseGeTsd();
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
FuncGraphPtr LoadMindIR(const std::string &file_name);
// init and exec dataset sub graph
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
@@ -51,6 +51,9 @@ void ValidateOperation(const AnfNodePtr &node) {
  if (abstract::IsInWhiteList(prim)) {
    return;
  }
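  // Skip checking primitives that were loaded from MindIR.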
| if (prim->HasAttr("is_load")) { | |||
| return; | |||
| } | |||
| if (prim->HasPyEvaluator()) { | |||
| MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; | |||
| return; | |||
@@ -273,6 +273,9 @@ class CNode : public AnfNode, public EffectInfoHolder {
  void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
  bool in_forward_flag() const { return in_forward_flag_; }
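  // Flag marking CNodes that were created when loading a graph from MindIR.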
  void set_load_flag(bool is_load) { is_load_ = is_load; }
  bool get_load_flag() { return is_load_; }
  VarPtr func_graph_as_var() const { return func_graph_as_var_; }
  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
@@ -304,6 +307,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
  bool stop_gradient_;
  bool in_forward_flag_ = false;
  bool effect_handled_ = false;
  bool is_load_ = false;
  // inputs_value_ store cnode input value and id in pynative mode
  // output_value_ store cnode value and id in pynative mode
  std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
@@ -68,6 +68,18 @@ AnfNodePtr FuncGraph::output() const {
  }
}
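// Collect the parameters without default values, i.e. the actual call inputs of the graph.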
const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
  std::vector<AnfNodePtr> input_params;
  for (auto const &node : parameters_) {
    MS_EXCEPTION_IF_NULL(node);
    auto parameter = dyn_cast<Parameter>(node);
    MS_EXCEPTION_IF_NULL(parameter);
    if (!parameter->has_default()) {
      input_params.push_back(parameter);
    }
  }
  return input_params;
}
ParameterPtr FuncGraph::add_parameter() {
  FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
  ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
@@ -160,6 +160,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
  abstract::AbstractFunctionPtr abstract();
  abstract::AbstractBasePtr ToAbstract() override;
  // Get the graph inputs, excluding parameters that have default values.
  const std::vector<AnfNodePtr> get_inputs() const;
  // Return the graph's output, or nullptr if not yet deduced.
  AnfNodePtr output() const;
  void set_output(const AnfNodePtr &value, bool force_new_ret = false);
@@ -91,6 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
  new_node->set_forward(old_node->forward().first, old_node->forward().second);
  new_node->set_inputs_value(old_node->inputs_value());
  new_node->set_attrs(old_node->attrs());
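  // Preserve the load flag so cloned nodes from a loaded graph are still treated as loaded during re-inference.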
  new_node->set_load_flag(old_node->get_load_flag());
  ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
  new_node->set_scope(scope);
  new_node->CloneUserData(old_node);
@@ -228,17 +228,14 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
  }
  if (!tensor_proto.has_data_type()) {
    MS_LOG(ERROR) << "mind_ir TensorProto has no data_type or name!";
    return nullptr;
    MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type or name!";
  }
  if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
    MS_LOG(ERROR) << "mind_ir TensorProto data_type is not supported yet!";
    return nullptr;
    MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type is not supported yet!";
  }
  tensor::TensorPtr tensor_info =
    std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
  MS_EXCEPTION_IF_NULL(tensor_info);
  return tensor_info;
}
@@ -253,9 +250,14 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
  string debug_info_name = ParseParameterName(parameter_proto.name());
  auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
  node->set_debug_info(debug_info_ptr);
  node->set_name(parameter_proto.name());
  node->set_name(debug_info_name);
  tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
  MS_EXCEPTION_IF_NULL(tensor_info);
  ParamInfoPtr param_info = std::make_shared<ParamInfo>();
  param_info->set_name(debug_info_name);
  tensor_info->set_param_info(param_info);
  auto tensor_abstract = tensor_info->ToAbstract();
  MS_EXCEPTION_IF_NULL(tensor_abstract);
  node->set_abstract(tensor_abstract);
@@ -284,13 +286,13 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
  string debug_info_name = ParseParameterName(value_proto.name());
  auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
  node->set_debug_info(debug_info_ptr);
  node->set_name(value_proto.name());
  node->set_name(debug_info_name);
  const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
  tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
  MS_EXCEPTION_IF_NULL(tensor_info);
  auto tensor_abstract = tensor_info->ToAbstract();
  MS_EXCEPTION_IF_NULL(tensor_abstract);
  node->set_abstract(tensor_abstract);
  anfnode_build_map_[value_proto.name()] = node;
@@ -300,15 +302,6 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
                                                const mind_ir::GraphProto &importProto) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
  for (int i = 0; i < importProto.parameter_size(); ++i) {
    const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
    if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
      MS_LOG(ERROR) << "Build parameter for funcgraph failed at index: " << i;
      return false;
    }
  }
  MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
  for (int i = 0; i < importProto.input_size(); ++i) {
    const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
@@ -317,6 +310,15 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
      return false;
    }
  }
  MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
  for (int i = 0; i < importProto.parameter_size(); ++i) {
    const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
    if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
      MS_LOG(ERROR) << "Build parameter for funcgraph failed at index: " << i;
      return false;
    }
  }
  return true;
}
@@ -745,7 +747,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
    inputs.push_back(anfnode_build_map_[input_name]);
  }
  prim->set_attr("is_load", MakeValue(true));
  auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
  MS_EXCEPTION_IF_NULL(cnode_ptr);
@@ -777,6 +779,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
  auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
  cnode_ptr->set_debug_info(debug_info_ptr);
  cnode_ptr->set_fullname_with_scope(fullname_with_scope);
  cnode_ptr->set_load_flag(true);
  anfnode_build_map_[node_name] = cnode_ptr;
  return cnode_ptr;
@@ -804,6 +807,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
    inputs.push_back(maketuple_ptr);
    auto return_node = outputFuncGraph->NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(return_node);
    return_node->set_load_flag(true);
    outputFuncGraph->set_return(return_node);
    MS_LOG(INFO) << "Construct funcgraph finished, all success.";
  } else {
@@ -812,6 +816,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
    inputs.push_back(cnode_ptr);
    auto return_node = outputFuncGraph->NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(return_node);
    return_node->set_load_flag(true);
    outputFuncGraph->set_return(return_node);
    MS_LOG(INFO) << "Construct funcgraph finished, all success!";
  }
@@ -20,7 +20,7 @@ Pre-defined building blocks or computing units to construct neural networks.
from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
from .learning_rate_schedule import *
from .dynamic_lr import *
from .cell import Cell, GraphKernel
from .cell import Cell, GraphKernel, GraphCell
from .layer import *
from .loss import *
from .optim import *
@@ -29,7 +29,7 @@ from .wrap import *
from .sparse import *
__all__ = ["Cell", "GraphKernel"]
__all__ = ["Cell", "GraphKernel", "GraphCell"]
__all__.extend(layer.__all__)
__all__.extend(loss.__all__)
__all__.extend(optim.__all__)
@@ -25,7 +25,7 @@ from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.context import ParallelMode
from .. import context
from .._c_expression import init_pipeline, Cell_
from .._c_expression import init_pipeline, Cell_, FuncGraph
from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec
@@ -1191,3 +1191,39 @@ class GraphKernel(Cell):
    def construct(self):
        raise NotImplementedError


class GraphCell(Cell):
    """
    Base class for running the graph loaded from a MindIR.

    This feature is still under development. Currently `GraphCell` does not support modifying the structure of
    the graph, and can only accept data whose shape and type are the same as the inputs used when exporting
    the MindIR.

    Args:
        graph (object): A compiled graph loaded from MindIR.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore.train import export, load
        >>>
        >>> net = nn.Conv2d(1, 1, kernel_size=3)
        >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> export(net, input, file_name="net", file_format="MINDIR")
        >>> graph = load("net.mindir")
        >>> net = nn.GraphCell(graph)
        >>> output = net(input)
    """
    def __init__(self, graph):
        super(GraphCell, self).__init__(auto_prefix=True)
        if not isinstance(graph, FuncGraph):
            raise TypeError(f"graph must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
        self.graph = graph

    def construct(self, *inputs):
        return self.graph(*inputs)

    def __call__(self, *inputs):
        return self.compile_and_run(*inputs)
@@ -22,10 +22,10 @@ from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp
from .amp import build_train_network
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
    build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint

__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
           "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
           "load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
           "load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
           "load_distributed_checkpoint"]
@@ -139,10 +139,22 @@ class Model:
        self._global_rank = _get_global_rank()
        self._parameter_broadcast = _get_parameter_broadcast()
        self._check_for_graph_cell(kwargs)
        self._train_network = self._build_train_network()
        self._build_eval_network(metrics, eval_network, eval_indexes)
        self._build_predict_network()
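
    # A GraphCell network only runs the loaded graph as-is: amp_level is ignored, and
    # loss_fn, optimizer and extra kwargs are not supported.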
    def _check_for_graph_cell(self, kwargs):
        if not isinstance(self._network, nn.GraphCell):
            return
        if self._amp_level != "O0":
            logger.warning("amp_level will not work when network is a GraphCell.")
        if self._loss_fn is not None or self._optimizer is not None:
            raise ValueError("Currently loss_fn and optimizer should be None when network is a GraphCell.")
        if kwargs:
            raise ValueError("Currently kwargs should be empty when network is a GraphCell.")

    def _process_amp_args(self, kwargs):
        if self._amp_level in ["O0", "O3"]:
            self._keep_bn_fp32 = False
@@ -586,6 +598,8 @@ class Model:
        >>> model.train(2, dataset)
        """
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
        if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
            raise ValueError("Sink mode is currently not supported when training with a GraphCell.")
        Validator.check_is_int(sink_size)
        dataset_size = train_dataset.get_dataset_size()
        if dataset_size == 0:
@@ -704,9 +718,12 @@ class Model:
        >>> acc = model.eval(dataset, dataset_sink_mode=False)
        """
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
        _device_number_check(self._parallel_mode, self._device_number)
        if not self._metric_fns:
            raise ValueError("metric fn can not be None or empty.")
        if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode is True:
            raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")

        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
@@ -38,6 +38,7 @@ from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
from mindspore.parallel._tensor import _load_tensor
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from .._c_expression import load_mindir

tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@@ -228,6 +229,49 @@ def _check_param_prefix(filter_prefix, param_name):
    return False
def load(file_name):
    """
    Load MindIR.

    The returned object can be executed by a `GraphCell`. However, there are some limitations to the current use
    of `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.

    Args:
        file_name (str): MindIR file name.

    Returns:
        Object, a compiled graph that can be executed by `GraphCell`.

    Raises:
        ValueError: MindIR file is incorrect.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore.train import export, load
        >>>
        >>> net = nn.Conv2d(1, 1, kernel_size=3)
        >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> export(net, input, file_name="net", file_format="MINDIR")
        >>> graph = load("net.mindir")
        >>> net = nn.GraphCell(graph)
        >>> output = net(input)
    """
    if not isinstance(file_name, str):
        raise ValueError("The file name must be a string.")
    if not os.path.exists(file_name):
        raise ValueError("The file does not exist.")
    if not file_name.endswith(".mindir"):
        raise ValueError("The file name should end with '.mindir', please input the correct file name.")

    logger.info("Execute the process of loading mindir.")
    graph = load_mindir(file_name)
    if graph is None:
        raise RuntimeError("Load MindIR failed.")
    return graph
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
    """
    Loads checkpoint info from a specified file.
@@ -22,7 +22,7 @@ from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.train.serialization import export
from mindspore.train.serialization import export, load


def weight_variable():
@@ -112,3 +112,26 @@ def test_export_lenet_grad_mindir():
    export(net, predict, label, file_name="lenet_grad", file_format='MINDIR')
    verify_name = "lenet_grad.mindir"
    assert os.path.exists(verify_name)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_load_mindir_and_run():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    network = LeNet5()
    network.set_train()
    inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    outputs0 = network(inputs0)
    inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
    export(network, inputs, file_name="test_lenet_load", file_format='MINDIR')
    mindir_name = "test_lenet_load.mindir"
    assert os.path.exists(mindir_name)
    graph = load(mindir_name)
    loaded_net = nn.GraphCell(graph)
    outputs_after_load = loaded_net(inputs0)
    assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())