Browse Source

fix mindspore models runtime on_device

tags/v0.7.0-beta
yankai 5 years ago
parent
commit
b3468fab89
6 changed files with 32 additions and 7 deletions
  1. +3
    -1
      mindspore/lite/src/common/anf_exporter/anf_exporter.cc
  2. +20
    -2
      mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc
  3. +2
    -1
      mindspore/lite/src/common/anf_importer/import_from_protobuf.cc
  4. +2
    -1
      mindspore/lite/src/common/anf_importer/import_from_protobuf.h
  5. +1
    -1
      mindspore/lite/tools/common/node_util.h
  6. +4
    -1
      mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc

+ 3
- 1
mindspore/lite/src/common/anf_exporter/anf_exporter.cc View File

@@ -385,9 +385,11 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::vector<schem
int i = 0;
for (auto outputTensor : outputTensors) {
std::string name = cnodeName + "_o:" + std::to_string(i);
auto msTensor = new schema::TensorT();
msTensor->nodeType = schema::NodeType_Parameter;
nodeIdMap[name] = graph->allTensors.size();
fbnode->outputIndex.emplace_back(graph->allTensors.size());
graph->allTensors.emplace_back(outputTensor);
graph->allTensors.emplace_back(msTensor);
i++;
}
return;


+ 20
- 2
mindspore/lite/src/common/anf_exporter/anf_populater/anf_reshape_populater.cc View File

@@ -23,10 +23,28 @@
namespace mindspore::lite {
int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto attr = std::make_unique<schema::FlattenT>();
auto attr = std::make_unique<schema::ReshapeT>();
MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree);
auto inputNode = cnodePtr->input(kAnfPopulaterTwo);
if (inputNode->isa<ValueNode>()) {
auto valueNode = inputNode->cast<ValueNodePtr>();
MS_ASSERT(valueNode != nullptr);
auto val = valueNode->value();
MS_ASSERT(val != nullptr);
if (val->isa<ValueTuple>()) {
auto tuple = val->cast<ValueTuplePtr>();
MS_ASSERT(tuple != nullptr);
for (size_t i = 0; i < tuple->size(); ++i) {
auto elem = tuple->value()[i]->cast<Int32ImmPtr>();
MS_ASSERT(elem != nullptr);
attr->shape.emplace_back(static_cast<int>(elem->value()));
}
}
}

node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Flatten;
node->primitive->value.type = schema::PrimitiveType_Reshape;
node->primitive->value.value = attr.release();
return 0;
}


+ 2
- 1
mindspore/lite/src/common/anf_importer/import_from_protobuf.cc View File

@@ -639,7 +639,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
}
#endif
#else

#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
@@ -1108,6 +1108,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
}
#endif

bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);


+ 2
- 1
mindspore/lite/src/common/anf_importer/import_from_protobuf.h View File

@@ -77,7 +77,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
const onnx::TensorProto &attr_tensor);
std::unordered_map<std::string, abstract::AbstractTensorPtr>
GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
#endif
#else
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
@@ -100,6 +100,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);

#endif


private:


+ 1
- 1
mindspore/lite/tools/common/node_util.h View File

@@ -232,7 +232,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
if (type == kCKHW2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kKCHW2KHWC) {
} else if (type == kCKHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c));
} else {


+ 4
- 1
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc View File

@@ -350,6 +350,9 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW
if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
weightTensor->format = schema::Format_CKHW;
}
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->format == schema::Format_KCHW) {
@@ -362,7 +365,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
}
if (status == 0) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_CKHW;
weightTensor->format = schema::Format_KHWC;
} else {
MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
// todo(00445839): consider varible weight condition


Loading…
Cancel
Save