| @@ -95,8 +95,8 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) { | |||
| MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found."; | |||
| return RET_ERROR; | |||
| } | |||
| auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front() | |||
| : out_tensors_.front()->GetQuantParams().front(); | |||
| auto quant_arg = !out_tensors_.front()->GetQuantParams().empty() ? out_tensors_.front()->GetQuantParams().front() | |||
| : in_tensors_.front()->GetQuantParams().front(); | |||
| int ret; | |||
| if (uint8_ptr_ == nullptr) { | |||
| if (inverse_) { | |||
| @@ -97,10 +97,6 @@ int ResizeBaseCPUKernel::CheckInputsOuputs() { | |||
| MS_LOG(ERROR) << "Resize input num should be no more than" << kMaxInputNum << ", but got " << in_tensors_.size(); | |||
| return RET_ERROR; | |||
| } | |||
| auto input = in_tensors_.at(0); | |||
| if (input == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (out_tensors_.size() != kOutputNum) { | |||
| MS_LOG(ERROR) << "Resize output num should be " << kOutputNum << ", but got " << out_tensors_.size(); | |||
| return RET_ERROR; | |||
| @@ -300,7 +300,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Floor; | |||
| op->primitive->value.value = attr.release(); | |||
| } else if (std::strcmp(node_name, "Neg") == 0) { | |||
| } else if (std::strcmp(node_name, "NEG") == 0) { | |||
| MS_LOG(DEBUG) << "parse TfliteNegParser"; | |||
| auto attr = std::make_unique<schema::NegT>(); | |||
| if (attr == nullptr) { | |||
| @@ -424,7 +424,7 @@ TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); | |||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); | |||
| TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); | |||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); | |||
| TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteNegParser()); | |||
| TfliteNodeRegister g_tfliteNegParser("NEG", new TfliteNegParser()); | |||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); | |||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); | |||
| @@ -46,7 +46,9 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||
| MS_LOG(ERROR) << "output tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8) { | |||
| if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && | |||
| (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || | |||
| GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) { | |||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * distributed under the License is distributed on an AS | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tflite/tflite_prelu_parser.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TflitePReLUParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||
| MS_LOG(DEBUG) << "parse TflitePReLUParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->channelShared = true; | |||
| op->primitive->value.type = schema::PrimitiveType_PReLU; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_tflitePReLUParser("PRELU", new TflitePReLUParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PRELU_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PRELU_PARSER_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TflitePReLUParser : public TfliteNodeParser { | |||
| public: | |||
| TflitePReLUParser() : TfliteNodeParser("PRELU") {} | |||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||
| std::map<int, int> *tensors_id_map) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PRELU_PARSER_H | |||
| @@ -46,8 +46,9 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl | |||
| MS_LOG(ERROR) << "output tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 || | |||
| GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8) { | |||
| if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && | |||
| (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 || | |||
| GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) { | |||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||