Browse Source

!8717 [lite] fix onnx operator converter bug

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
5d4e95ad95
2 changed files with 24 additions and 2 deletions
  1. +1
    -1
      mindspore/lite/tools/common/graph_util.cc
  2. +23
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc

+ 1
- 1
mindspore/lite/tools/common/graph_util.cc View File

@@ -689,7 +689,7 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo
MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
node->primitive->value.AsConcat()->axis = axis_map[origin_axis < 0 ? origin_axis + 4 : origin_axis];
}
if (type == schema::PrimitiveType_Split) {
MS_ASSERT(node->primitive->value.AsSplit() != nullptr);


+ 23
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc View File

@@ -16,6 +16,9 @@

#include "tools/converter/parser/onnx/onnx_tile_parser.h"
#include <memory>
#include <numeric>
#include <vector>
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"

namespace mindspore {
namespace lite {
@@ -36,7 +39,26 @@ STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}

const auto &onnx_tile_multiple = onnx_node.input(1);
int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_tile_multiple);
if (index == -1) {
MS_LOG(ERROR) << "can not find node: " << onnx_tile_multiple;
return RET_ERROR;
}
auto tile_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index];
if (tile_attr->data.data() == nullptr) {
MS_LOG(ERROR) << "power's attr pow can't be obtained.";
return RET_INVALID_OP_ATTR;
}
int element_size = std::accumulate(tile_attr->dims.begin(), tile_attr->dims.end(), 1, std::multiplies<int>());
std::vector<int> multiples;
std::vector<int> dims;
for (int i = 0; i < element_size; ++i) {
multiples.push_back(reinterpret_cast<int *>(tile_attr->data.data())[i]);
dims.push_back(i);
}
attr->multiples = multiples;
attr->dims = dims;
op->primitive->value.type = schema::PrimitiveType_Tile;
op->primitive->value.value = attr.release();
return RET_OK;


Loading…
Cancel
Save