Browse Source

!7121 [MSLITE] Fix bug of onxx cast parser.

Merge pull request !7121 from wangshaocong/bugfix_master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
79b974eb82
6 changed files with 12 additions and 5 deletions
  1. +1
    -1
      mindspore/lite/src/ops/cast.cc
  2. +5
    -0
      mindspore/lite/src/ops/primitive_c.cc
  3. +3
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc
  4. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
  5. +1
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc
  6. +0
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.h

+ 1
- 1
mindspore/lite/src/ops/cast.cc View File

@@ -95,7 +95,7 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
return RET_OK;
}

if (input->data_type() != GetSrcT()) {
if (GetSrcT() != 0 && input->data_type() != GetSrcT()) {
MS_LOG(ERROR) << "input dataType is error";
return RET_INPUT_TENSOR_ERROR;
}


+ 5
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -131,6 +131,7 @@
#include "src/ops/custom_predict.h"
#include "src/ops/custom_normalize.h"
#include "src/ops/custom_extract_features.h"
#include "src/ops/upsample.h"
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#endif
@@ -692,6 +693,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new CustomNormalize(primitive);
case schema::PrimitiveType_CustomExtractFeatures:
return new CustomExtractFeatures(primitive);
case schema::PrimitiveType_Upsample:
return new Upsample(primitive);

#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
@@ -960,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
return NewPrimitiveC<CustomNormalize>(primitive);
case schema::PrimitiveType_CustomExtractFeatures:
return NewPrimitiveC<CustomExtractFeatures>(primitive);
case schema::PrimitiveType_Upsample:
return NewPrimitiveC<Upsample>(primitive);

#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:


+ 3
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc View File

@@ -15,6 +15,7 @@
*/

#include "tools/converter/parser/onnx/onnx_cast_parser.h"
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include <memory>

namespace mindspore {
@@ -40,7 +41,8 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "to") {
attr->dstT = static_cast<int32_t>(onnx_node_attr.i());
attr->dstT = static_cast<int32_t>(
OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i())));
}
}



+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h View File

@@ -43,9 +43,9 @@ class OnnxModelParser : public ModelParser {
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;

private:
TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);

private:
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);

STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);


mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc → mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc View File

@@ -15,7 +15,7 @@
*/

#include <memory>
#include "tools/converter/parser/onnx/onnx_unsample_parser.h"
#include "tools/converter/parser/onnx/onnx_upsample_parser.h"

namespace mindspore {
namespace lite {

mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h → mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.h View File


Loading…
Cancel
Save