
deconv2d add outputPadding

tags/v1.2.0-rc1
yeyunpeng 4 years ago
parent commit a7bae1413d
7 changed files with 223 additions and 108 deletions
  1. +2 -0    mindspore/lite/schema/ops.fbs
  2. +6 -0    mindspore/lite/src/ops/deconv2d.cc
  3. +2 -1    mindspore/lite/src/ops/deconv2d.h
  4. +164 -90 mindspore/lite/tools/converter/anf_transform.cc
  5. +19 -0   mindspore/lite/tools/converter/anf_transform.h
  6. +28 -17  mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc
  7. +2 -0    mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h

+2 -0  mindspore/lite/schema/ops.fbs

@@ -480,6 +480,8 @@ table DeConv2D {
dilateH: int;
hasBias: bool = false; // DEPRECATED
activationType: ActivationType = 0;
outputPaddingW: int;
outputPaddingH: int;
}

table DeConv2DGradFilter {
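
A compatibility note (my assumption, based on FlatBuffers defaulting scalar fields to 0): models serialized before this schema change read back zero output padding, so their inferred shapes are unchanged. A minimal sketch using the generated accessors:

// Sketch only, not part of the commit: reading the new fields from a
// serialized DeConv2D primitive. Absent scalar fields default to 0, so
// pre-existing models see no output padding.
const auto *deconv = primitive->value_as_DeConv2D();
int output_padding_w = deconv->outputPaddingW();  // 0 for older models
int output_padding_h = deconv->outputPaddingH();  // 0 for older models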


+6 -0  mindspore/lite/src/ops/deconv2d.cc

@@ -47,6 +47,8 @@ int DeConv2D::GetPadRight() const { return this->primitive_->value.AsDeConv2D()-
int DeConv2D::GetDilateW() const { return this->primitive_->value.AsDeConv2D()->dilateW; }
int DeConv2D::GetDilateH() const { return this->primitive_->value.AsDeConv2D()->dilateH; }
int DeConv2D::GetActivationType() const { return this->primitive_->value.AsDeConv2D()->activationType; }
int DeConv2D::GetOutputPaddingW() const { return this->primitive_->value.AsDeConv2D()->outputPaddingW; }
int DeConv2D::GetOutputPaddingH() const { return this->primitive_->value.AsDeConv2D()->outputPaddingH; }

void DeConv2D::SetFormat(int format) { this->primitive_->value.AsDeConv2D()->format = (schema::Format)format; }
void DeConv2D::SetGroup(int group) { this->primitive_->value.AsDeConv2D()->group = group; }
@@ -295,6 +297,8 @@ int DeConv2D::GetPadRight() const { return this->primitive_->value_as_DeConv2D()
int DeConv2D::GetDilateW() const { return this->primitive_->value_as_DeConv2D()->dilateW(); }
int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()->dilateH(); }
int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); }
int DeConv2D::GetOutputPaddingW() const { return this->primitive_->value_as_DeConv2D()->outputPaddingW(); }
int DeConv2D::GetOutputPaddingH() const { return this->primitive_->value_as_DeConv2D()->outputPaddingH(); }

PrimitiveC *DeConv2DCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<DeConv2D>(primitive);
@@ -347,6 +351,8 @@ int DeConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::
MS_LOG(ERROR) << "unsupported pad mode for deconv";
return RET_ERROR;
}
output_h += GetOutputPaddingH();
output_w += GetOutputPaddingW();
std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
output->set_shape(out_shape);
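
For context, a sketch of how the new padding term enters the transposed-convolution output size (illustrative only; names are mine, and the value accumulated before this point depends on the pad-mode branch above):

// Standard transposed-convolution size formula with output padding (sketch):
//   out = (in - 1) * stride - pad_top - pad_bottom + dilation * (kernel - 1) + 1 + output_padding
int ExpectedDeconvOutputH(int in_h, int stride_h, int pad_u, int pad_d,
                          int dilate_h, int kernel_h, int output_padding_h) {
  return (in_h - 1) * stride_h - pad_u - pad_d + dilate_h * (kernel_h - 1) + 1 + output_padding_h;
}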



+2 -1  mindspore/lite/src/ops/deconv2d.h

@@ -71,7 +71,8 @@ class DeConv2D : public PrimitiveC {
int GetDilateW() const;
int GetDilateH() const;
int GetActivationType() const;

int GetOutputPaddingW() const;
int GetOutputPaddingH() const;
int PadUp() const { return this->pad_u_; }
int PadDown() const { return this->pad_d_; }
int PadLeft() const { return this->pad_l_; }


+164 -90  mindspore/lite/tools/converter/anf_transform.cc

@@ -59,62 +59,8 @@ AnfTransform::AnfTransform() = default;

AnfTransform::~AnfTransform() = default;

FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_ASSERT(nullptr != old_graph);
if (config == nullptr) {
MS_LOG(ERROR) << "config should be specified";
return nullptr;
}
if (old_graph->has_flag("HasTransformed")) {
old_graph->set_flag("HasTransformed", false);
return old_graph;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) {
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);

if (config->fmk == converter::FmkType_MS) {
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
mindir_adjust_pass->SetFmkType(config->fmk);
mindir_adjust_pass->SetQuantType(config->quantType);
if (!mindir_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "mindir adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
auto mindir_inputs_adjust_pass = std::make_shared<opt::MindirInputAdjustOpPass>();
if (!mindir_inputs_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "mindir inputs adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
}

// onnx pre adjustment
if (config->fmk == converter::FmkType_ONNX) {
auto onnx_adjust_pass = std::make_shared<opt::OnnxInputAdjustOpPass>();
if (!onnx_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "onnx adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
}

if (config->fmk == lite::converter::FmkType_TF) {
auto functionalize_control_op_pass = std::make_shared<opt::FunctionalizeControlOpPass>();
if (!functionalize_control_op_pass->Run(old_graph)) {
MS_LOG(ERROR) << "functionalize control op pass failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
}

if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
config->fmk == lite::converter::FmkType_ONNX) {
graph_pm->AddPass(std::make_shared<opt::WhilePass>());
graph_pm->AddPass(std::make_shared<opt::IfPass>());
}

// for now - training is not supporting fuse operations
if (!config->trainModel) {
@@ -137,26 +83,11 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>());
}
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
weight_format_hardcode_pass->SetFmkType(config->fmk);
weight_format_hardcode_pass->SetQuantType(config->quantType);
graph_pm->AddPass(weight_format_hardcode_pass);
auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
weight_format_transform_pass->SetFmkType(config->fmk);
weight_format_transform_pass->SetQuantType(config->quantType);
graph_pm->AddPass(weight_format_transform_pass);
auto infershape_pass = std::make_shared<opt::InferShapePass>();
infershape_pass->SetFmkType(config->fmk);
graph_pm->AddPass(infershape_pass);
auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
slice_prepose_pass->SetFmkType(config->fmk);
graph_pm->AddPass(slice_prepose_pass);

if (config->fmk == lite::converter::FmkType_MS) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
if (remove_unused_cast_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
return nullptr;
return RET_ERROR;
}
remove_unused_cast_pass->SetFmkType(config->fmk);
fusion_pm->AddPass(remove_unused_cast_pass);
@@ -165,11 +96,55 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>();
if (remove_unused_transpose_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass should be specified";
return nullptr;
return RET_ERROR;
}
remove_unused_transpose_pass->SetFmkType(config->fmk);
fusion_pm->AddPass(remove_unused_transpose_pass);
}
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
optimizer->AddPassManager(fusion_pm);
return RET_OK;
}

int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) {
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
config->fmk == lite::converter::FmkType_ONNX) {
graph_pm->AddPass(std::make_shared<opt::WhilePass>());
graph_pm->AddPass(std::make_shared<opt::IfPass>());
}
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
weight_format_hardcode_pass->SetFmkType(config->fmk);
weight_format_hardcode_pass->SetQuantType(config->quantType);
graph_pm->AddPass(weight_format_hardcode_pass);
auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
weight_format_transform_pass->SetFmkType(config->fmk);
weight_format_transform_pass->SetQuantType(config->quantType);
graph_pm->AddPass(weight_format_transform_pass);
auto infershape_pass = std::make_shared<opt::InferShapePass>();
infershape_pass->SetFmkType(config->fmk);
graph_pm->AddPass(infershape_pass);
auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
slice_prepose_pass->SetFmkType(config->fmk);
graph_pm->AddPass(slice_prepose_pass);
optimizer->AddPassManager(graph_pm);
return RET_OK;
}

int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
const converter::Flags *config) {
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
if (config->fmk == lite::converter::FmkType_TFLITE) {
convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>());
}
optimizer->AddPassManager(convert_pm);
return RET_OK;
}

int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
const converter::Flags *config) {
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
if (!config->trainModel) {
auto inne_context_ptr = std::make_shared<lite::InnerContext>();
@@ -179,47 +154,90 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
update_conv2d_param_pass->SetFmkType(config->fmk);
const_fold_pm->AddPass(update_conv2d_param_pass);
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
if (config->fmk == lite::converter::FmkType_TFLITE) {
convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>());
}
optimizer->AddPassManager(const_fold_pm);
optimizer->AddPassManager(convert_pm);
optimizer->AddPassManager(fusion_pm);
optimizer->AddPassManager(graph_pm);
auto new_graph = optimizer->Optimize(old_graph);
if (new_graph == nullptr) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
return nullptr;
return RET_OK;
}

int AnfTransform::RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
switch (config->fmk) {
case converter::FmkType_MS:
return RunMindirAdjustPass(old_graph, config);
case converter::FmkType_ONNX:
return RunOnnxAdjustPass(old_graph, config);
case converter::FmkType_TF:
return RunTFAdjustPass(old_graph, config);
default:
return RET_OK;
}
}

int AnfTransform::RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
mindir_adjust_pass->SetFmkType(config->fmk);
mindir_adjust_pass->SetQuantType(config->quantType);
if (!mindir_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "mindir adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
auto mindir_inputs_adjust_pass = std::make_shared<opt::MindirInputAdjustOpPass>();
if (!mindir_inputs_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "mindir inputs adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
return RET_OK;
}

int AnfTransform::RunOnnxAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
// onnx pre adjustment
auto onnx_adjust_pass = std::make_shared<opt::OnnxInputAdjustOpPass>();
if (!onnx_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "onnx adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
return RET_OK;
}

int AnfTransform::RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto functionalize_control_op_pass = std::make_shared<opt::FunctionalizeControlOpPass>();
if (!functionalize_control_op_pass->Run(old_graph)) {
MS_LOG(ERROR) << "functionalize control op pass failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
return RET_OK;
}

int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config,
const FuncGraphPtr &new_graph) {
// quant
if (config->quantType == schema::QuantType_PostTraining) {
if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) {
MS_LOG(ERROR) << "bitNum must be valid pos num.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
return RET_ERROR;
}
this->mQuantizer =
std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, std::stoi(config->bitNum));
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return nullptr;
return RET_ERROR;
}
} else if (config->quantType == schema::QuantType_WeightQuant) {
if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
MS_LOG(ERROR) << "weight quant input param error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
return RET_ERROR;
}
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize,
config->quantWeightChannel, config->bitNum);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return nullptr;
return RET_ERROR;
}
}
if (mQuantizer != nullptr) {
@@ -228,9 +246,65 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
if (status != RET_OK) {
MS_LOG(ERROR) << "Quant failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return RET_ERROR;
}
}
return RET_OK;
}

FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_ASSERT(nullptr != old_graph);
if (config == nullptr) {
MS_LOG(ERROR) << "config should be specified";
return nullptr;
}
if (old_graph->has_flag("HasTransformed")) {
old_graph->set_flag("HasTransformed", false);
return old_graph;
}

auto status = RunAdjustPass(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run Adjust pass failed.";
return nullptr;
}

auto optimizer = std::make_shared<opt::GraphOptimizer>();

status = AddConstFoldPass(optimizer, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add const fold pass failed.";
return nullptr;
}

status = AddConvertPass(optimizer, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add convert pass failed.";
return nullptr;
}

status = AddFusionPass(optimizer, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add fusion pass failed.";
return nullptr;
}
status = AddGraphPass(optimizer, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add graph pass failed.";
return nullptr;
}

auto new_graph = optimizer->Optimize(old_graph);
if (new_graph == nullptr) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
return nullptr;
}

status = DoQuantize(old_graph, config, new_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do Quantize failed.";
return nullptr;
}
return new_graph;
}
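
The refactor above splits the former monolithic TransformSingleFuncGraph into helpers that each build one PassManager and register it on the shared GraphOptimizer. A hypothetical helper following the same pattern (the name and pass choice are illustrative, not part of the commit):

// Hypothetical example of the pattern: build a PassManager, add passes, register it.
int AnfTransform::AddMyPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
                            const converter::Flags *config) {
  auto my_pm = std::make_shared<opt::PassManager>("my pass manager", true);
  auto infershape_pass = std::make_shared<opt::InferShapePass>();
  infershape_pass->SetFmkType(config->fmk);
  my_pm->AddPass(infershape_pass);
  optimizer->AddPassManager(my_pm);
  return RET_OK;
}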



+19 -0  mindspore/lite/tools/converter/anf_transform.h

@@ -19,6 +19,7 @@

#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"
#include "tools/common/storage.h"
#include "tools/converter/converter_flags.h"
@@ -39,6 +40,24 @@ class AnfTransform {
std::vector<ValueNodePtr> *vnodes);
FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;

int AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);

int AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);

int AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);

int AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);

int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

int RunOnnxAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);
};
} // namespace lite
} // namespace mindspore


+28 -17  mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc

@@ -53,42 +53,36 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
return true;
}

lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx DeConvParser";
auto attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}

int OnnxDeConvParser::ParseParameters(const onnx::NodeProto &onnx_node,
const std::unique_ptr<schema::DeConv2DT> &attr) {
attr->padMode = schema::PadMode_NOTSET;
attr->group = 1;
attr->strideW = 1;
attr->strideH = 1;
attr->dilateW = 1;
attr->dilateH = 1;

for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") {
attr->group = static_cast<int32_t>(onnx_node_attr.i());
} else if (onnx_node_attr.name() == "dilations") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return nullptr;
return RET_ERROR;
}
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return nullptr;
return RET_ERROR;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return nullptr;
return RET_ERROR;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
@@ -97,7 +91,7 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
} else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) {
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
return nullptr;
return RET_ERROR;
}
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
@@ -106,7 +100,7 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
} else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return nullptr;
return RET_ERROR;
}
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
@@ -115,13 +109,30 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
attr->format = schema::Format::Format_NHWC;
} else {
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
return nullptr;
return RET_ERROR;
}
} else if (onnx_node_attr.name() == "output_padding") {
MS_LOG(ERROR) << "output_padding param hasn't been supported";
return nullptr;
attr->outputPaddingH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->outputPaddingW = static_cast<int32_t>(onnx_node_attr.ints(1));
}
}
return RET_OK;
}

lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx DeConvParser";
auto attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}

auto status = ParseParameters(onnx_node, attr);
if (status != RET_OK) {
MS_LOG(ERROR) << "Parse parameters failed.";
return nullptr;
}

const auto &onnx_conv_weight = onnx_node.input(1);
auto node_iter =
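
For illustration (values assumed, not taken from the commit), a typical 2x-upsampling ConvTranspose now parses instead of failing on output_padding:

// Sketch: an ONNX ConvTranspose with strides=[2,2], kernel_shape=[3,3],
// pads=[1,1,1,1], output_padding=[1,1] fills the DeConv2DT attr as
//   strideH = strideW = 2, kernelH = kernelW = 3,
//   padUp = padDown = padLeft = padRight = 1,
//   outputPaddingH = outputPaddingW = 1
// With input H = W = 16, the inferred output side is
//   (16 - 1) * 2 - 1 - 1 + 1 * (3 - 1) + 1 + 1 = 32, i.e. exact 2x upsampling.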


+2 -0  mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h

@@ -32,6 +32,8 @@ class OnnxDeConvParser : public OnnxNodeParser {

private:
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::PrimitiveT *primitive);

int ParseParameters(const onnx::NodeProto &onnx_node, const std::unique_ptr<schema::DeConv2DT> &attr);
};
} // namespace lite
} // namespace mindspore

