Merge pull request !7121 from wangshaocong/bugfix_mastertags/v1.1.0
| @@ -95,7 +95,7 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| if (input->data_type() != GetSrcT()) { | |||||
| if (GetSrcT() != 0 && input->data_type() != GetSrcT()) { | |||||
| MS_LOG(ERROR) << "input dataType is error"; | MS_LOG(ERROR) << "input dataType is error"; | ||||
| return RET_INPUT_TENSOR_ERROR; | return RET_INPUT_TENSOR_ERROR; | ||||
| } | } | ||||
| @@ -131,6 +131,7 @@ | |||||
| #include "src/ops/custom_predict.h" | #include "src/ops/custom_predict.h" | ||||
| #include "src/ops/custom_normalize.h" | #include "src/ops/custom_normalize.h" | ||||
| #include "src/ops/custom_extract_features.h" | #include "src/ops/custom_extract_features.h" | ||||
| #include "src/ops/upsample.h" | |||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #endif | #endif | ||||
| @@ -692,6 +693,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new CustomNormalize(primitive); | return new CustomNormalize(primitive); | ||||
| case schema::PrimitiveType_CustomExtractFeatures: | case schema::PrimitiveType_CustomExtractFeatures: | ||||
| return new CustomExtractFeatures(primitive); | return new CustomExtractFeatures(primitive); | ||||
| case schema::PrimitiveType_Upsample: | |||||
| return new Upsample(primitive); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| case schema::PrimitiveType_ActivationGrad: | case schema::PrimitiveType_ActivationGrad: | ||||
| @@ -960,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { | |||||
| return NewPrimitiveC<CustomNormalize>(primitive); | return NewPrimitiveC<CustomNormalize>(primitive); | ||||
| case schema::PrimitiveType_CustomExtractFeatures: | case schema::PrimitiveType_CustomExtractFeatures: | ||||
| return NewPrimitiveC<CustomExtractFeatures>(primitive); | return NewPrimitiveC<CustomExtractFeatures>(primitive); | ||||
| case schema::PrimitiveType_Upsample: | |||||
| return NewPrimitiveC<Upsample>(primitive); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| case schema::PrimitiveType_ActivationGrad: | case schema::PrimitiveType_ActivationGrad: | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "tools/converter/parser/onnx/onnx_cast_parser.h" | #include "tools/converter/parser/onnx/onnx_cast_parser.h" | ||||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -40,7 +41,8 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| const auto &attribute_name = onnx_node_attr.name(); | const auto &attribute_name = onnx_node_attr.name(); | ||||
| if (attribute_name == "to") { | if (attribute_name == "to") { | ||||
| attr->dstT = static_cast<int32_t>(onnx_node_attr.i()); | |||||
| attr->dstT = static_cast<int32_t>( | |||||
| OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i()))); | |||||
| } | } | ||||
| } | } | ||||
| @@ -43,9 +43,9 @@ class OnnxModelParser : public ModelParser { | |||||
| schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | ||||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | const QuantType &quantType = QuantType_QUANT_NONE) override; | ||||
| private: | |||||
| TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | |||||
| static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | |||||
| private: | |||||
| std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | ||||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); | STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include <memory> | #include <memory> | ||||
| #include "tools/converter/parser/onnx/onnx_unsample_parser.h" | |||||
| #include "tools/converter/parser/onnx/onnx_upsample_parser.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||