Browse Source

develop mindir load and run

tags/v1.2.0-rc1
wangnan39@huawei.com 5 years ago
parent
commit
0681c2597b
20 changed files with 351 additions and 55 deletions
  1. +74
    -2
      mindspore/ccsrc/pipeline/jit/action.cc
  2. +1
    -0
      mindspore/ccsrc/pipeline/jit/init.cc
  3. +57
    -29
      mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
  4. +44
    -0
      mindspore/ccsrc/pipeline/jit/parse/resolve.cc
  5. +3
    -0
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  6. +1
    -0
      mindspore/ccsrc/pipeline/jit/pipeline.h
  7. +3
    -0
      mindspore/ccsrc/pipeline/jit/validator.cc
  8. +4
    -0
      mindspore/core/ir/anf.h
  9. +12
    -0
      mindspore/core/ir/func_graph.cc
  10. +2
    -0
      mindspore/core/ir/func_graph.h
  11. +1
    -0
      mindspore/core/ir/func_graph_cloner.cc
  12. +23
    -18
      mindspore/core/load_mindir/anf_model_parser.cc
  13. +2
    -2
      mindspore/nn/__init__.py
  14. +37
    -1
      mindspore/nn/cell.py
  15. +2
    -2
      mindspore/train/__init__.py
  16. +17
    -0
      mindspore/train/model.py
  17. +44
    -0
      mindspore/train/serialization.py
  18. +24
    -1
      tests/st/export_and_load/test_train_mindir.py
  19. +0
    -0
      tests/st/export_and_load/text_air.py
  20. +0
    -0
      tests/st/export_and_load/text_lite_mindir.py

+ 74
- 2
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -66,6 +66,10 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
for (auto &node : manager->all_nodes()) {
MS_EXCEPTION_IF_NULL(node);
const AbstractBasePtr &prev_inferred = node->abstract();
// Keep previous inferred value for CNode if is loaded from MindIR.
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
continue;
}
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
node->set_abstract(nullptr);
@@ -113,6 +117,69 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
return ret;
}

const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
auto manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
FuncGraphPtr loaded_graph = nullptr;
size_t loaded_graph_num = 0;
auto all_graphs = manager->func_graphs();
for (auto &graph : all_graphs) {
MS_EXCEPTION_IF_NULL(graph);
if (graph->has_attr("is_load")) {
loaded_graph = graph;
loaded_graph_num += 1;
}
}
if (loaded_graph_num == 0) {
return nullptr;
}
if (loaded_graph_num == 1) {
return loaded_graph;
}
MS_LOG(EXCEPTION) << "The loaded sub graph currently should less than 2, but got " << loaded_graph_num;
}

void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
MS_EXCEPTION_IF_NULL(res);
auto manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
FuncGraphPtr root_graph = *(manager->roots().begin());
auto root_inputs = root_graph->get_inputs();
auto loaded_inputs = loaded_graph->get_inputs();

size_t root_inputs_num = root_inputs.size();
size_t loaded_inputs_num = loaded_inputs.size();
if (root_inputs_num != loaded_inputs_num) {
MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
<< loaded_inputs_num;
}
for (size_t index = 0; index < root_inputs_num; index++) {
auto root_input = root_inputs[index];
auto loaded_input = loaded_inputs[index];

auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
MS_EXCEPTION_IF_NULL(root_shape);
MS_EXCEPTION_IF_NULL(loaded_shape);
MS_EXCEPTION_IF_NULL(root_type);
MS_EXCEPTION_IF_NULL(loaded_type);

if (root_shape->shape() != loaded_shape->shape()) {
MS_EXCEPTION(ValueError) << "The " << index
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
<< ", input shape of loaded graph: " << loaded_shape->ToString();
}
if (root_type->type_id() != loaded_type->type_id()) {
MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
<< " th input type differ from loaded graph. Input type: " << root_type->ToString()
<< ", input type of loaded graph: " << loaded_type->ToString();
}
}
}

bool ParseAction(const ResourcePtr &res) {
if (!res->input()) {
MS_LOG(EXCEPTION) << "Parse error";
@@ -255,12 +322,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
}

FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec = res->args_spec();
auto context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
context->ParallelParameterContextInitShape(func_graph);

// get original loaded graph to check inputs later
auto loaded_graph_ptr = GetLoadedGraph(res);
// suppose that there is not KeywordArgument for the top graph
// get the hyper parameter
for (const auto &param : func_graph->parameters()) {
@@ -294,7 +363,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
}
}
}

// check input after abstract when there is a loaded graph
if (loaded_graph_ptr != nullptr) {
CheckRootInputShapeAndType(res, loaded_graph_ptr);
}
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
return true;
}


+ 1
- 0
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -111,6 +111,7 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");

(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
(py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load model as Graph.");

(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")


+ 57
- 29
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc View File

@@ -203,6 +203,19 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
return true;
}

bool ConvertFuncGraph(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting FuncGraph object";
auto func_graph = obj.cast<FuncGraphPtr>();
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
return false;
}
auto new_fg = BasicClone(func_graph);
new_fg->set_attr("is_load", MakeValue(true));
*data = new_fg;
return true;
}

bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting slice object";

@@ -368,47 +381,21 @@ bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype
}
} // namespace

bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
// check parameter valid
if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer";
return false;
}

bool ret = true;
bool ConvertSingleData(const py::object &obj, ValuePtr *const data) {
MS_EXCEPTION_IF_NULL(data);
ValuePtr converted = nullptr;
if (py::isinstance<py::none>(obj)) {
converted = kNone;
} else if (py::isinstance<py::bool_>(obj)) {
converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
} else if (py::isinstance<py::int_>(obj)) {
ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
} else if (py::isinstance<py::float_>(obj)) {
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
} else if (py::isinstance<py::str>(obj)) {
converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
} else if (py::isinstance<py::dict>(obj)) {
ret = ConvertDict(obj, &converted, use_signature);
} else if (py::isinstance<py::slice>(obj)) {
ret = ConvertSlice(obj, &converted);
} else if (py::isinstance<py::ellipsis>(obj)) {
converted = kEllipsis;
} else if (py::isinstance<py::tuple>(obj)) {
ret = ConvertTuple(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
ret = ConvertCellList(obj, &converted, use_signature);
} else if (py::isinstance<Cell>(obj)) {
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
} else if (py::isinstance<py::list>(obj)) {
ret = ConvertList(obj, &converted, use_signature);
} else if (py::isinstance<py::module>(obj)) {
ConvertNameSpace(obj, &converted);
} else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
ConvertDataClass(obj, &converted);
} else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
ret = ConvertPrimitive(obj, &converted, use_signature);
} else if (py::isinstance<MetaFuncGraph>(obj)) {
ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
} else if (py::isinstance<Type>(obj)) {
converted = obj.cast<TypePtr>();
} else if (py::isinstance<Tensor>(obj)) {
@@ -425,9 +412,50 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
} else {
ret = ConvertOtherObj(obj, &converted);
return false;
}
*data = converted;
return true;
}

bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
// check parameter valid
if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer";
return false;
}

ValuePtr converted = nullptr;
bool ret = ConvertSingleData(obj, &converted);
if (ret) {
*data = converted;
return true;
}
if (py::isinstance<py::int_>(obj)) {
ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
} else if (py::isinstance<py::float_>(obj)) {
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
} else if (py::isinstance<py::dict>(obj)) {
ret = ConvertDict(obj, &converted, use_signature);
} else if (py::isinstance<py::slice>(obj)) {
ret = ConvertSlice(obj, &converted);
} else if (py::isinstance<py::tuple>(obj)) {
ret = ConvertTuple(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
ret = ConvertCellList(obj, &converted, use_signature);
} else if (py::isinstance<Cell>(obj)) {
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
} else if (py::isinstance<py::list>(obj)) {
ret = ConvertList(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
ret = ConvertPrimitive(obj, &converted, use_signature);
} else if (py::isinstance<MetaFuncGraph>(obj)) {
ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
} else if (py::isinstance<FuncGraph>(obj)) {
ret = ConvertFuncGraph(obj, &converted);
} else {
ret = ConvertOtherObj(obj, &converted);
}
*data = converted;
return ret;
}


+ 44
- 0
mindspore/ccsrc/pipeline/jit/parse/resolve.cc View File

@@ -113,6 +113,49 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
return para_node;
}

void BroadenCNodeAbstract(const FuncGraphPtr &func_graph) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
for (const AnfNodePtr &node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto abstract = node->abstract();
if (abstract != nullptr) {
node->set_abstract(abstract->Broaden());
}
}
}

void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
if (!value->isa<FuncGraph>()) {
return;
}
auto resolved_graph = value->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(resolved_graph);
if (!resolved_graph->has_attr("is_load")) {
return;
}
auto top_graph = Parser::GetTopFuncGraph();
std::vector<AnfNodePtr> input_params;
for (auto const &param : resolved_graph->parameters()) {
auto param_ptr = dyn_cast<Parameter>(param);
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->has_default()) {
param_ptr->set_func_graph(top_graph);
func_graph->add_used_global_parameters(param_ptr);

// update top_graph
top_graph->add_parameter(param_ptr);
size_t hyper_param_count = top_graph->hyper_param_count();
top_graph->set_hyper_param_count(hyper_param_count + 1);
} else {
input_params.push_back(param_ptr);
}
}
resolved_graph->set_parameters(input_params);
BroadenCNodeAbstract(resolved_graph);
}

bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
AnfNodePtr output = nullptr;
if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
@@ -146,6 +189,7 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
return false;
}
MS_EXCEPTION_IF_NULL(convert_result);
ConvertLoadedGraph(func_graph, convert_result);
output = NewValueNode(convert_result);
if (convert_result->isa<tensor::Tensor>()) {
output = GetMixedPrecisionCastHelp(func_graph, output);


+ 3
- 0
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -48,6 +48,7 @@
#include "pybind_api/pybind_patch.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "load_mindir/load_model.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/constants.h"
#include "ps/util.h"
@@ -1096,6 +1097,8 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s
#endif
}

FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::LoadMindIR(file_name); }

void ReleaseGeTsd() {
auto context_ptr = MsContext::GetInstance();
if (context_ptr != nullptr) {


+ 1
- 0
mindspore/ccsrc/pipeline/jit/pipeline.h View File

@@ -140,6 +140,7 @@ void ClearResAtexit();
void ReleaseGeTsd();

void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
FuncGraphPtr LoadMindIR(const std::string &file_name);

// init and exec dataset sub graph
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,


+ 3
- 0
mindspore/ccsrc/pipeline/jit/validator.cc View File

@@ -51,6 +51,9 @@ void ValidateOperation(const AnfNodePtr &node) {
if (abstract::IsInWhiteList(prim)) {
return;
}
if (prim->HasAttr("is_load")) {
return;
}
if (prim->HasPyEvaluator()) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
return;


+ 4
- 0
mindspore/core/ir/anf.h View File

@@ -273,6 +273,9 @@ class CNode : public AnfNode, public EffectInfoHolder {
void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
bool in_forward_flag() const { return in_forward_flag_; }

void set_load_flag(bool is_load) { is_load_ = is_load; }
bool get_load_flag() { return is_load_; }

VarPtr func_graph_as_var() const { return func_graph_as_var_; }

const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
@@ -304,6 +307,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
bool stop_gradient_;
bool in_forward_flag_ = false;
bool effect_handled_ = false;
bool is_load_ = false;
// inputs_value_ store cnode input value and id in pynative mode
// output_value_ store cnode value and id in pynative mode
std::vector<std::pair<ValuePtr, std::string>> inputs_value_;


+ 12
- 0
mindspore/core/ir/func_graph.cc View File

@@ -68,6 +68,18 @@ AnfNodePtr FuncGraph::output() const {
}
}

const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
std::vector<AnfNodePtr> input_params;
for (auto const &node : parameters_) {
MS_EXCEPTION_IF_NULL(node);
auto parameter = dyn_cast<Parameter>(node);
if (!parameter->has_default()) {
input_params.push_back(parameter);
}
}
return input_params;
}

ParameterPtr FuncGraph::add_parameter() {
FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_func_graph);


+ 2
- 0
mindspore/core/ir/func_graph.h View File

@@ -160,6 +160,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
abstract::AbstractFunctionPtr abstract();
abstract::AbstractBasePtr ToAbstract() override;

// get function graph inputs, but parameters
const std::vector<AnfNodePtr> get_inputs() const;
// Return the graph's output, or nullptr if not yet deduced.
AnfNodePtr output() const;
void set_output(const AnfNodePtr &value, bool force_new_ret = false);


+ 1
- 0
mindspore/core/ir/func_graph_cloner.cc View File

@@ -91,6 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
new_node->set_forward(old_node->forward().first, old_node->forward().second);
new_node->set_inputs_value(old_node->inputs_value());
new_node->set_attrs(old_node->attrs());
new_node->set_load_flag(old_node->get_load_flag());
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope);
new_node->CloneUserData(old_node);


+ 23
- 18
mindspore/core/load_mindir/anf_model_parser.cc View File

@@ -228,17 +228,14 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
}

if (!tensor_proto.has_data_type()) {
MS_LOG(ERROR) << "mind_ir TensorProto has no data_type or name!";
return nullptr;
MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type or name!";
}
if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "mind_ir TensorProto data_type is not support yet!";
return nullptr;
MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type is not support yet!";
}

tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
return tensor_info;
}

@@ -253,9 +250,14 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
string debug_info_name = ParseParameterName(parameter_proto.name());
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
node->set_debug_info(debug_info_ptr);
node->set_name(parameter_proto.name());
node->set_name(debug_info_name);

tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
MS_EXCEPTION_IF_NULL(tensor_info);
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
param_info->set_name(debug_info_name);
tensor_info->set_param_info(param_info);

auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);
@@ -284,13 +286,13 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
string debug_info_name = ParseParameterName(value_proto.name());
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
node->set_debug_info(debug_info_ptr);
node->set_name(value_proto.name());
node->set_name(debug_info_name);

const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);

tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
MS_EXCEPTION_IF_NULL(tensor_info);
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);

anfnode_build_map_[value_proto.name()] = node;
@@ -300,15 +302,6 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
for (int i = 0; i < importProto.parameter_size(); ++i) {
const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
return false;
}
}

MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
for (int i = 0; i < importProto.input_size(); ++i) {
const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
@@ -317,6 +310,15 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
return false;
}
}

MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
for (int i = 0; i < importProto.parameter_size(); ++i) {
const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
return false;
}
}
return true;
}

@@ -745,7 +747,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc

inputs.push_back(anfnode_build_map_[input_name]);
}
prim->set_attr("is_load", MakeValue(true));
auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);

@@ -777,6 +779,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
cnode_ptr->set_debug_info(debug_info_ptr);
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
cnode_ptr->set_load_flag(true);

anfnode_build_map_[node_name] = cnode_ptr;
return cnode_ptr;
@@ -804,6 +807,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
inputs.push_back(maketuple_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
return_node->set_load_flag(true);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
} else {
@@ -812,6 +816,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
inputs.push_back(cnode_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
return_node->set_load_flag(true);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
}


+ 2
- 2
mindspore/nn/__init__.py View File

@@ -20,7 +20,7 @@ Pre-defined building blocks or computing units to construct neural networks.
from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
from .learning_rate_schedule import *
from .dynamic_lr import *
from .cell import Cell, GraphKernel
from .cell import Cell, GraphKernel, GraphCell
from .layer import *
from .loss import *
from .optim import *
@@ -29,7 +29,7 @@ from .wrap import *
from .sparse import *


__all__ = ["Cell", "GraphKernel"]
__all__ = ["Cell", "GraphKernel", "GraphCell"]
__all__.extend(layer.__all__)
__all__.extend(loss.__all__)
__all__.extend(optim.__all__)


+ 37
- 1
mindspore/nn/cell.py View File

@@ -25,7 +25,7 @@ from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.context import ParallelMode
from .. import context
from .._c_expression import init_pipeline, Cell_
from .._c_expression import init_pipeline, Cell_, FuncGraph
from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec
@@ -1191,3 +1191,39 @@ class GraphKernel(Cell):

def construct(self):
raise NotImplementedError


class GraphCell(Cell):
"""
Base class for running the graph loaded from a MindIR.

This feature is still under development. Currently `GraphCell` do not support modifying the structure of the
diagram, and can only use data that shape and type are the same as the input when exporting the MindIR.

Args:
graph (object): A compiled graph loaded from MindIR.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.train import export, load
>>>
>>> net = nn.Conv2d(1, 1, kernel_size=3)
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> export(net, input, file_name="net", file_format="MINDIR")
>>> graph = load("net.mindir")
>>> net = nn.GraphCell(graph)
>>> output = net(input)
"""
def __init__(self, graph):
super(GraphCell, self).__init__(auto_prefix=True)
if not isinstance(graph, FuncGraph):
raise TypeError(f"graph must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
self.graph = graph

def construct(self, *inputs):
return self.graph(*inputs)

def __call__(self, *inputs):
return self.compile_and_run(*inputs)

+ 2
- 2
mindspore/train/__init__.py View File

@@ -22,10 +22,10 @@ from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp
from .amp import build_train_network
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint

__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
"load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
"load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
"load_distributed_checkpoint"]

+ 17
- 0
mindspore/train/model.py View File

@@ -139,10 +139,22 @@ class Model:
self._global_rank = _get_global_rank()
self._parameter_broadcast = _get_parameter_broadcast()

self._check_for_graph_cell(kwargs)
self._train_network = self._build_train_network()
self._build_eval_network(metrics, eval_network, eval_indexes)
self._build_predict_network()

def _check_for_graph_cell(self, kwargs):
if not isinstance(self._network, nn.GraphCell):
return
if self._amp_level != "O0":
logger.warning("amp_level will not work when network is a GraphCell.")

if self._loss_fn is not None or self._optimizer is not None:
raise ValueError("Currently loss_fn and optimizer should be None when network is a GraphCell. ")
if kwargs:
raise ValueError("Currently kwargs should be empty when network is a GraphCell. ")

def _process_amp_args(self, kwargs):
if self._amp_level in ["O0", "O3"]:
self._keep_bn_fp32 = False
@@ -586,6 +598,8 @@ class Model:
>>> model.train(2, dataset)
"""
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
raise ValueError("Sink mode is currently not supported when training with a GraphCell.")
Validator.check_is_int(sink_size)
dataset_size = train_dataset.get_dataset_size()
if dataset_size == 0:
@@ -704,9 +718,12 @@ class Model:
>>> acc = model.eval(dataset, dataset_sink_mode=False)
"""
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)

_device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode is True:
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")

cb_params = _InternalCallbackParam()
cb_params.eval_network = self._eval_network


+ 44
- 0
mindspore/train/serialization.py View File

@@ -38,6 +38,7 @@ from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
from mindspore.parallel._tensor import _load_tensor
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from .._c_expression import load_mindir


tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@@ -228,6 +229,49 @@ def _check_param_prefix(filter_prefix, param_name):
return False


def load(file_name):
"""
Load MindIR.

The returned object can be executed by a `GraphCell`. However, there are some limitations to the current use
of `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.

Args:
file_name (str): MindIR file name.

Returns:
Object, a compiled graph that can executed by `GraphCell`.

Raises:
ValueError: MindIR file is incorrect.

Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.train import export, load
>>>
>>> net = nn.Conv2d(1, 1, kernel_size=3)
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> export(net, input, file_name="net", file_format="MINDIR")
>>> graph = load("net.mindir")
>>> net = nn.GraphCell(graph)
>>> output = net(input)
"""
if not isinstance(file_name, str):
raise ValueError("The file name must be string.")
if not os.path.exists(file_name):
raise ValueError("The file is not exist.")
if not file_name.endswith(".mindir"):
raise ValueError("The MindIR should end with mindir, please input the correct file name.")

logger.info("Execute the process of loading mindir.")
graph = load_mindir(file_name)
if graph is None:
raise RuntimeError("Load MindIR failed.")
return graph


def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
"""
Loads checkpoint info from a specified file.


tests/st/export/test_train_mindir.py → tests/st/export_and_load/test_train_mindir.py View File

@@ -22,7 +22,7 @@ from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.train.serialization import export
from mindspore.train.serialization import export, load


def weight_variable():
@@ -112,3 +112,26 @@ def test_export_lenet_grad_mindir():
export(net, predict, label, file_name="lenet_grad", file_format='MINDIR')
verify_name = "lenet_grad.mindir"
assert os.path.exists(verify_name)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_load_mindir_and_run():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
network = LeNet5()
network.set_train()

inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
outputs0 = network(inputs0)

inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
export(network, inputs, file_name="test_lenet_load", file_format='MINDIR')
mindir_name = "test_lenet_load.mindir"
assert os.path.exists(mindir_name)

graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(inputs0)
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())

tests/st/export/text_air.py → tests/st/export_and_load/text_air.py View File


tests/st/export/text_lite_mindir.py → tests/st/export_and_load/text_lite_mindir.py View File


Loading…
Cancel
Save