From c9db2ed81fe67df8b151d861bf3753134b5a4fba Mon Sep 17 00:00:00 2001 From: xuanyue Date: Wed, 16 Sep 2020 17:24:49 +0800 Subject: [PATCH] fix onnx pool and tflite pad --- mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc | 2 +- .../converter/legacy_optimizer/graph/format_trans_pass.cc | 6 ++---- .../lite/tools/converter/parser/onnx/onnx_pool_parser.cc | 6 +++--- .../lite/tools/converter/parser/tflite/tflite_pad_parser.cc | 6 ++++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc index da57b7f250..b745f6c5ff 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc @@ -131,7 +131,7 @@ int CastOpenCLKernel::Run() { kernel::LiteKernel *OpenCLCastKernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, - const lite::Context *ctx, const kernel::KernelKey &desc, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, const mindspore::lite::PrimitiveC *primitive) { auto *kernel = new (std::nothrow) CastOpenCLKernel(opParameter, inputs, outputs); if (kernel == nullptr) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc index 624175c97b..e2bafd4adb 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc @@ -67,9 +67,8 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { } for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) { - if (node->inputIndex.at(inputIndexIdx) == inputIdx) { + for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) { + if ((*iter)->inputIndex.at(inputIndexIdx) == inputIdx) { STATUS status = RET_OK; iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status); if (status != RET_OK) { @@ -89,7 +88,6 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]}; transed = true; } - break; } } } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc index 46ad0aac12..36ccd735cb 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -83,9 +83,9 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod if (onnx_node_attr.ints_size() == 4) { attr->padMode = schema::PadMode_CAFFE; attr->padUp = static_cast(onnx_node_attr.ints(0)); - attr->padDown = static_cast(onnx_node_attr.ints(1)); - attr->padLeft = static_cast(onnx_node_attr.ints(0)); - attr->padRight = static_cast(onnx_node_attr.ints(1)); + attr->padDown = static_cast(onnx_node_attr.ints(2)); + attr->padLeft = static_cast(onnx_node_attr.ints(1)); + attr->padRight = static_cast(onnx_node_attr.ints(3)); } } if (attribute_name == "ceil_mode") { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc index 3e50c489d4..7cc8caccb2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -74,8 +74,6 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_o MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; return RET_INVALID_OP_ATTR; } - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); } else { MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; return RET_NOT_SUPPORT; @@ -86,6 +84,10 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_o AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format::Format_NHWC); + if (std::strcmp(node_name, "MirrorPad") == 0) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), + tflite_tensors.size(), schema::Format::Format_NHWC); + } AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format::Format_NHWC); return RET_OK;