Merge pull request !6230 from 徐安越/mastertags/v1.0.0
| @@ -32,6 +32,7 @@ constexpr int RET_PARAM_INVALID = -3; /**< Invalid parameter.*/ | |||
| constexpr int RET_NO_CHANGE = -4; /**< No change. */ | |||
| constexpr int RET_SUCCESS_EXIT = -5; /**< No error but exit. */ | |||
| constexpr int RET_MEMORY_FAILED = -6; /**< Fail to create memory. */ | |||
| constexpr int RET_NOT_SUPPORT = -7; /**< Fail to support. */ | |||
| /* Executor error code, range: [-101,-200] */ | |||
| constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to check range. */ | |||
| @@ -53,6 +54,10 @@ constexpr int RET_FORMAT_ERR = -401; /**< Failed to checking tensor format. */ | |||
| /* InferShape error code, range: [-501,-600] */ | |||
| constexpr int RET_INFER_ERR = -501; /**< Failed to infer shape. */ | |||
| constexpr int RET_INFER_INVALID = -502; /**< Invalid infer shape before runtime. */ | |||
| /* User input param error code, range: [-601, 700]*/ | |||
| constexpr int RET_INPUT_PARAM_INVALID = -601; /**< Invalid input param by user. */ | |||
| constexpr int RET_INPUT_PARAM_LACK = -602; /**< LACK input param by user. */ | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -203,6 +203,8 @@ union PrimitiveType { | |||
| LogGrad, | |||
| BatchToSpaceND, | |||
| LshProjection, | |||
| HashtableLookup, | |||
| SkipGram, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -948,3 +948,12 @@ table BlackBox { | |||
| table LshProjection { | |||
| type : LshProjectionType; | |||
| } | |||
| table HashtableLookup { | |||
| } | |||
| table SkipGram { | |||
| includeAllGrams : bool; | |||
| maxSkipSize : int; | |||
| ngramSize : int; | |||
| } | |||
| @@ -109,15 +109,15 @@ int RunConverter(int argc, const char **argv) { | |||
| std::unique_ptr<converter::Flags> flags(new (std::nothrow) converter::Flags); | |||
| if (flags == nullptr) { | |||
| MS_LOG(ERROR) << "new flags error "; | |||
| std::cout << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << std::endl; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| auto status = flags->Init(argc, argv); | |||
| if (status == RET_SUCCESS_EXIT) { | |||
| return status; | |||
| } | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "converter::Flags Init failed: " << status; | |||
| std::cout << "CONVERTER::FLAGS INIT FAILED" << std::endl; | |||
| if (status != RET_OK) { | |||
| if (status != RET_SUCCESS_EXIT) { | |||
| MS_LOG(ERROR) << "converter::Flags Init failed: " << status; | |||
| std::cout << "CONVERTER::FLAGS INIT FAILED:" << status << std::endl; | |||
| } | |||
| return status; | |||
| } | |||
| // Load graph | |||
| @@ -148,13 +148,14 @@ int RunConverter(int argc, const char **argv) { | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk; | |||
| return 1; | |||
| std::cout << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << std::endl; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| } | |||
| status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); | |||
| if (fb_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Convert model return nullptr"; | |||
| std::cout << "CONVERT RESULT: FAILED!" << std::endl; | |||
| std::cout << "CONVERT RESULT FAILED:" << status << std::endl; | |||
| return status; | |||
| } | |||
| @@ -164,14 +165,14 @@ int RunConverter(int argc, const char **argv) { | |||
| status = storage.Save(*fb_graph, flags->outputFile); | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "Save graph failed"; | |||
| std::cout << "SAVE GRAPH FAILED!" << std::endl; | |||
| return RET_ERROR; | |||
| std::cout << "SAVE GRAPH FAILED:" << status << std::endl; | |||
| return status; | |||
| } | |||
| delete fb_graph; | |||
| MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!"; | |||
| std::cout << "CONVERT RESULT: SUCCESS!" << std::endl; | |||
| return RET_OK; | |||
| std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl; | |||
| return status; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -55,7 +55,7 @@ int Flags::Init(int argc, const char **argv) { | |||
| if (err.IsSome()) { | |||
| std::cerr << err.Get(); | |||
| std::cerr << this->Usage() << std::endl; | |||
| return 1; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->help) { | |||
| @@ -64,21 +64,21 @@ int Flags::Init(int argc, const char **argv) { | |||
| } | |||
| if (this->modelFile.empty()) { | |||
| std::cerr << "INPUT MISSING: model file path is necessary"; | |||
| return 1; | |||
| return RET_INPUT_PARAM_LACK; | |||
| } | |||
| if (this->outputFile.empty()) { | |||
| std::cerr << "INPUT MISSING: output file path is necessary"; | |||
| return 1; | |||
| return RET_INPUT_PARAM_LACK; | |||
| } | |||
| if (this->outputFile.rfind('/') == this->outputFile.length() - 1) { | |||
| std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path"; | |||
| return 1; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->fmkIn.empty()) { | |||
| std::cerr << "INPUT MISSING: fmk is necessary"; | |||
| return 1; | |||
| return RET_INPUT_PARAM_LACK; | |||
| } | |||
| if (this->inputInferenceTypeIn == "FLOAT") { | |||
| this->inputInferenceType = TypeId::kNumberTypeFloat; | |||
| @@ -87,7 +87,7 @@ int Flags::Init(int argc, const char **argv) { | |||
| } else { | |||
| std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8", | |||
| this->inputInferenceTypeIn.c_str(); | |||
| return 1; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->inferenceTypeIn == "FLOAT") { | |||
| @@ -97,7 +97,7 @@ int Flags::Init(int argc, const char **argv) { | |||
| } else { | |||
| std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8", | |||
| this->inferenceTypeIn.c_str(); | |||
| return 1; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->fmkIn == "CAFFE") { | |||
| @@ -110,12 +110,12 @@ int Flags::Init(int argc, const char **argv) { | |||
| this->fmk = FmkType_ONNX; | |||
| } else { | |||
| std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX"; | |||
| return 1; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->fmk != FmkType_CAFFE && !weightFile.empty()) { | |||
| std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; | |||
| return 1; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->quantTypeIn == "AwareTraining") { | |||
| this->quantType = QuantType_AwareTraining; | |||
| @@ -127,7 +127,7 @@ int Flags::Init(int argc, const char **argv) { | |||
| this->quantType = QuantType_QUANT_NONE; | |||
| } else { | |||
| std::cerr << "INPUT ILLEGAL: quantType must be AwareTraining|WeightQuant|PostTraining"; | |||
| return 1; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| @@ -137,9 +137,9 @@ int Flags::Init(int argc, const char **argv) { | |||
| this->trainModel = false; | |||
| } else { | |||
| std::cerr << "INPUT ILLEGAL: trainModel must be true|false "; | |||
| return 1; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| } // namespace converter | |||
| } // namespace lite | |||
| @@ -176,6 +176,9 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (attr->group != 1) { | |||
| MS_LOG(ERROR) << "group conv hasn't supported"; | |||
| return RET_NOT_SUPPORT; | |||
| } else { | |||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| op->primitive->value.value = attr.release(); | |||
| @@ -78,5 +78,6 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||
| } | |||
| OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser()); | |||
| OnnxNodeRegistrar g_onnxLRNxParser("LRN", new OnnxLrnParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -42,19 +42,19 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| std::vector<int> dims; | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, dims)) { | |||
| MS_LOG(ERROR) << "get expand_dims -> dim failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->dim = -1; | |||
| MS_LOG(ERROR) << "The attr dim is folded by TFLite."; | |||
| return RET_ERROR; | |||
| attr->dim = dims[0]; | |||
| op->primitive->value.type = schema::PrimitiveType_ExpandDims; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||
| std::vector<schema::Format> *tensors_format, | |||
| std::map<int, int> *tensors_id_map) { | |||
| MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::HashtableLookupT> attr = std::make_unique<schema::HashtableLookupT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_HashtableLookup; | |||
| op->primitive->value.value = attr.release(); | |||
| for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { | |||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TfliteHashtableLookupParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {} | |||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||
| std::map<int, int> *tensors_id_map) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H | |||
| @@ -42,18 +42,43 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->paddingMode = schema::PaddingMode_CONSTANT; | |||
| attr->constantValue = 0.0f; | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) { | |||
| MS_LOG(ERROR) << "get pad -> paddings failed"; | |||
| return RET_ERROR; | |||
| std::vector<std::string> node_name_str; | |||
| Split(op->name, &node_name_str, "-"); | |||
| const char *node_name = node_name_str.data()->c_str(); | |||
| if (std::strcmp(node_name, "Pad") == 0) { | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->paddingMode = schema::PaddingMode_CONSTANT; | |||
| attr->constantValue = 0.0f; | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) { | |||
| MS_LOG(ERROR) << "get pad -> paddings failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (std::strcmp(node_name, "MirrorPad") == 0) { | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| switch (tflite_attr->mode) { | |||
| case tflite::MirrorPadMode_REFLECT: | |||
| attr->paddingMode = schema::PaddingMode_REFLECT; | |||
| break; | |||
| case tflite::MirrorPadMode_SYMMETRIC: | |||
| attr->paddingMode = schema::PaddingMode_SYMMETRIC; | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; | |||
| return RET_INVALID_OP_ATTR; | |||
| } | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Pad; | |||
| @@ -67,5 +92,6 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o | |||
| } | |||
| TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser()); | |||
| TfliteNodeRegister g_tfliteMirorPadParser("MirrorPad", new TflitePadParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tflite/tflite_skip_gram_parser.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TfliteSkipGramParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||
| MS_LOG(DEBUG) << "parse TfliteSkipGramParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SkipGramT> attr = std::make_unique<schema::SkipGramT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| const auto &tflite_attr = tflite_op->builtin_options.AsSkipGramOptions(); | |||
| if (tflite_attr == nullptr) { | |||
| MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->includeAllGrams = tflite_attr->include_all_ngrams; | |||
| attr->maxSkipSize = tflite_attr->max_skip_size; | |||
| attr->ngramSize = tflite_attr->ngram_size; | |||
| op->primitive->value.type = schema::PrimitiveType_SkipGram; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| TfliteNodeRegister g_TfliteSkiGramParser("SKipGram", new TfliteSkipGramParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TfliteSkipGramParser : public TfliteNodeParser { | |||
| public: | |||
| TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {} | |||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||
| std::map<int, int> *tensors_id_map) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H | |||
| @@ -57,7 +57,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{ | |||
| {tflite::BuiltinOperator_POW, "Pow"}, | |||
| {tflite::BuiltinOperator_ARG_MIN, "Argmin"}, | |||
| {tflite::BuiltinOperator_CEIL, "Ceil"}, | |||
| // {tflite::BuiltinOperator_EXPAND_DIMS, "ExpandDims"}, | |||
| {tflite::BuiltinOperator_EXPAND_DIMS, "ExpandDims"}, | |||
| {tflite::BuiltinOperator_FILL, "Fill"}, | |||
| {tflite::BuiltinOperator_DIV, "Div"}, | |||
| {tflite::BuiltinOperator_FLOOR, "flOOR"}, | |||
| @@ -117,6 +117,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{ | |||
| {tflite::BuiltinOperator_UNIQUE, "Unique"}, | |||
| {tflite::BuiltinOperator_UNPACK, "Unstack"}, | |||
| {tflite::BuiltinOperator_CUSTOM, "Custom"}, | |||
| {tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"}, | |||
| }; | |||
| std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{ | |||
| @@ -33,7 +33,7 @@ class ReturnCode { | |||
| statusCode = status; | |||
| } | |||
| } | |||
| STATUS GetReturnCode() { | |||
| STATUS GetReturnCode() const { | |||
| return statusCode; | |||
| } | |||
| private: | |||
| @@ -85,8 +85,8 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { | |||
| param_value->set_tensor_type(type_id); | |||
| param_value->set_format(tensor->GetFormat()); | |||
| if (tensor->MutableData() != nullptr) { | |||
| auto size = tensor->ElementsNum(); | |||
| auto tensor_data = new (std::nothrow) float[size]; | |||
| auto size = tensor->Size(); | |||
| auto tensor_data = new (std::nothrow) uint8_t[size]; | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_data is nullptr"; | |||
| return nullptr; | |||
| @@ -98,7 +98,7 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { | |||
| return nullptr; | |||
| } | |||
| param_value->set_tensor_addr(tensor_data); | |||
| param_value->set_tensor_size(size * sizeof(float) / sizeof(uint8_t)); | |||
| param_value->set_tensor_size(size); | |||
| } | |||
| parameter->set_default_param(param_value); | |||
| return parameter; | |||