Browse Source

fixed format trans

tags/v0.7.0-beta
kai00 5 years ago
parent
commit
9638139e27
4 changed files with 4 additions and 3 deletions
  1. +1
    -0
      mindspore/lite/src/common/anf_importer/import_from_protobuf.cc
  2. +1
    -1
      mindspore/lite/tools/common/node_util.cc
  3. +1
    -1
      mindspore/lite/tools/common/node_util.h
  4. +1
    -1
      mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc

+ 1
- 0
mindspore/lite/src/common/anf_importer/import_from_protobuf.cc View File

@@ -1168,6 +1168,7 @@ int AnfImporterFromProtobuf::Import() {
const onnx::GraphProto &graphBuild = onnx_model_->graph();
if (!BuildFuncGraph(dstGraph, graphBuild)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
func_graph_ = nullptr;
return RET_ERROR;
}
func_graph_ = dstGraph;


+ 1
- 1
mindspore/lite/tools/common/node_util.cc View File

@@ -96,7 +96,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize,
schema::PrimitiveType_BatchNorm};
schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm};

static const std::vector<schema::PrimitiveType> fp32FullOpList = {
schema::PrimitiveType_Concat, schema::PrimitiveType_Add,


+ 1
- 1
mindspore/lite/tools/common/node_util.h View File

@@ -234,7 +234,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kCKHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c));
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));


+ 1
- 1
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc View File

@@ -351,7 +351,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW
if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
if (fmkType == converter::FmkType_MS) {
weightTensor->format = schema::Format_CKHW;
}
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms


Loading…
Cancel
Save