From: @zhengjun10 Reviewed-by: @hangangqiang Signed-off-by:tags/v1.2.0-rc1
| @@ -264,6 +264,7 @@ union PrimitiveType { | |||
| If, | |||
| GeLU, | |||
| Gru, | |||
| NonZero, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -236,7 +236,8 @@ union PrimitiveType { | |||
| LpNormalization, | |||
| DropoutGrad, | |||
| MaximumGrad, | |||
| MinimumGrad | |||
| MinimumGrad, | |||
| NonZero, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -1241,3 +1241,6 @@ table Merge { | |||
| table GeLU { | |||
| approximate : bool = false; | |||
| } | |||
| table NonZero { | |||
| } | |||
| @@ -1143,3 +1143,6 @@ table LpNormalization { | |||
| axis : int; | |||
| p : int; | |||
| } | |||
| table NonZero { | |||
| } | |||
| @@ -0,0 +1,124 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/nonzero.h" | |||
| #include <algorithm> | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/tensor.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int NonZero::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_NonZero; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_NonZero) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::NonZeroT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| PopulaterQuantParam(prim, inputs); | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int NonZero::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_NonZero(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_NonZero return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateNonZero(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_NonZero, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *NonZeroCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<NonZero>(primitive); } | |||
| Registry NonZeroRegistry(schema::PrimitiveType_NonZero, NonZeroCreator); | |||
| #endif | |||
| template <typename T> | |||
| void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape) { | |||
| int input_count = inputs[0]->ElementsNum(); | |||
| int input_dim_size = inputs[0]->shape().empty() ? 1 : inputs[0]->shape().size(); | |||
| (*out_shape)[0] = input_dim_size; | |||
| int nonzero_size = 0; | |||
| for (int i = 0; i < input_count; i++) { | |||
| if (static_cast<int>(data[i]) != 0) { | |||
| nonzero_size++; | |||
| } | |||
| } | |||
| if (nonzero_size == 0) { | |||
| *out_shape = {}; | |||
| } else { | |||
| (*out_shape)[1] = nonzero_size / input_dim_size; | |||
| } | |||
| } | |||
| int NonZero::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| MS_ASSERT(this->primitive_ != nullptr); | |||
| MS_ASSERT(inputs_.size() == 1); | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(input->data_type()); | |||
| output->set_format(input->format()); | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| std::vector<int> out_shape; | |||
| if (inputs_.size() == kSingleNum) { | |||
| auto input_tensor = inputs_.at(0); | |||
| if (input_tensor->data_c() == nullptr) { | |||
| MS_LOG(INFO) << "Do infer shape in runtime."; | |||
| return RET_INFER_INVALID; | |||
| } | |||
| switch (input_tensor->data_type()) { | |||
| case kNumberTypeFloat: { | |||
| auto data = reinterpret_cast<float *>(input_tensor->MutableData()); | |||
| CalShape<float>(data, inputs_, &out_shape); | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "NonZero weight tensor has unsupported dataType: " << input_tensor->data_type(); | |||
| return RET_INFER_ERR; | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "inputs tensor size invalid."; | |||
| return RET_INFER_ERR; | |||
| } | |||
| output->set_shape(out_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_NONZERO_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_NONZERO_H_ | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class NonZero : public PrimitiveC { | |||
| public: | |||
| NonZero() = default; | |||
| ~NonZero() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(NonZero, PrimitiveC); | |||
| explicit NonZero(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_OPS_NONZERO_H_ | |||
| @@ -0,0 +1,105 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/nonzero_fp32.h" | |||
| #include "include/errorcode.h" | |||
| #include "nnacl/op_base.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/tensor.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_NonZero; | |||
| namespace mindspore::kernel { | |||
| int NonZeroCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int NonZeroCPUKernel::ReSize() { return RET_OK; } | |||
| int NonZeroCPUKernel::Run() { | |||
| auto in_tensor = in_tensors_.front(); | |||
| auto out_tensor = out_tensors_.front(); | |||
| auto input_data = reinterpret_cast<float *>(in_tensor->MutableData()); | |||
| auto output_data = reinterpret_cast<int *>(out_tensor->MutableData()); | |||
| auto input_dim_size = in_tensor->shape().size(); | |||
| if (out_tensor->shape().size() != 2) { | |||
| MS_LOG(ERROR) << "out tensor shape size must be equal to 2!"; | |||
| return RET_ERROR; | |||
| } | |||
| auto non_zero_nums = out_tensor->shape()[1]; | |||
| int non_zero_count = 0; | |||
| std::vector coordiate_values(in_tensor->shape().size(), 0); | |||
| for (int i = 0; i < in_tensor->ElementsNum(); i += 1) { | |||
| if (input_data[i] != 0) { | |||
| for (size_t j = 0; j < input_dim_size; j++) { | |||
| output_data[non_zero_count + j * non_zero_nums] = coordiate_values[j]; | |||
| } | |||
| non_zero_count++; | |||
| } | |||
| for (int idx = input_dim_size - 1; idx >= 0; --idx) { | |||
| if (coordiate_values[idx] != in_tensor->shape()[idx] - 1) { | |||
| coordiate_values[idx] = coordiate_values[idx] + 1; | |||
| break; | |||
| } | |||
| coordiate_values[idx] = 0; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuNonZeroFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| if (opParameter == nullptr) { | |||
| MS_LOG(ERROR) << "Input opParameter is nullptr!"; | |||
| return nullptr; | |||
| } | |||
| if (ctx == nullptr) { | |||
| MS_LOG(ERROR) << "Input context is nullptr!"; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| if (ctx->thread_num_ == 0) { | |||
| MS_LOG(ERROR) << "context thread num is 0!"; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| auto *kernel = new (std::nothrow) NonZeroCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new NonZeroCPUKernel fail!"; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NonZero, CpuNonZeroFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| namespace mindspore::kernel { | |||
| class NonZeroCPUKernel : public LiteKernel { | |||
| public: | |||
| NonZeroCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~NonZeroCPUKernel() = default; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| protected: | |||
| int thread_count_ = 1; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_ | |||
| @@ -100,7 +100,8 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| } | |||
| } | |||
| if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) { | |||
| if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF || | |||
| config->fmk == lite::converter::FmkType_ONNX) { | |||
| graph_pm->AddPass(std::make_shared<opt::WhilePass>()); | |||
| graph_pm->AddPass(std::make_shared<opt::IfPass>()); | |||
| } | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_if_parser.h" | |||
| #include <memory> | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| lite::PrimitiveC *OnnxIfParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx IfParser"; | |||
| auto attr = std::make_unique<schema::IfT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return nullptr; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_If; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxIfParser("If", new OnnxIfParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxIfParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxIfParser() : OnnxNodeParser("If") {} | |||
| ~OnnxIfParser() override = default; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_loop_parser.h" | |||
| #include <memory> | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| lite::PrimitiveC *OnnxLoopParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx LoopParser"; | |||
| auto attr = std::make_unique<schema::WhileT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return nullptr; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_While; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxLoopParser("Loop", new OnnxLoopParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxLoopParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxLoopParser() : OnnxNodeParser("Loop") {} | |||
| ~OnnxLoopParser() override = default; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H | |||
| @@ -54,29 +54,60 @@ class OnnxModelParser : public ModelParser { | |||
| private: | |||
| STATUS InitOriginModel(const std::string &model_file); | |||
| STATUS ConvertNodes(); | |||
| STATUS ConvertConstTensors(); | |||
| STATUS ConvertGraphInputs(); | |||
| STATUS ConvertGraphOutputs(); | |||
| STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs); | |||
| STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs, | |||
| const std::string &root_node_name); | |||
| STATUS ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||
| std::vector<AnfNodePtr> *graph_inputs, const std::string &root_node_name); | |||
| STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map); | |||
| STATUS ConvertGraphInputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *nodes_map); | |||
| STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr, | |||
| const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map); | |||
| STATUS BuildReturnNode(const FuncGraphPtr &func_graph_ptr, const std::vector<AnfNodePtr> &return_inputs); | |||
| STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor); | |||
| STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type); | |||
| STATUS BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||
| STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode); | |||
| STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||
| STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||
| STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, const std::string &name); | |||
| STATUS BuildCNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs, | |||
| lite::PrimitiveC *primitive_c, std::string loop_name); | |||
| STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode); | |||
| STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||
| lite::PrimitiveC *primitive_c); | |||
| STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, lite::PrimitiveC *primitive_c); | |||
| STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, lite::PrimitiveC *primitive_c, | |||
| const std::string &name); | |||
| STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||
| STATUS ParseQuantParam(const onnx::NodeProto &onnx_node); | |||
| STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | |||
| STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | |||
| STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not); | |||
| bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node); | |||
| STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node, | |||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||
| const std::string &root_node_name); | |||
| STATUS ConvertIfOnnxNode(const onnx::NodeProto &onnx_node, std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||
| const std::string &root_node_name); | |||
| STATUS AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs, | |||
| const std::string &loop_node_name, std::vector<AnfNodePtr> *body_graph_inputs, | |||
| int act_output_num); | |||
| STATUS BuildCondGraph(const FuncGraphPtr &cond_graph, const AnfNodePtr &root_while_node, int inputs_num, | |||
| const std::string &cond_graph_name); | |||
| STATUS ConvertIfSubgraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph, | |||
| const std::string &subgrah_name, const std::string &if_node_name, | |||
| const std::string &root_node_name); | |||
| onnx::ModelProto onnx_model_; | |||
| onnx::GraphProto onnx_graph_; | |||
| std::unordered_map<std::string, AnfNodePtr> nodes_; | |||
| FuncGraphPtr func_graph_ptr_ = nullptr; | |||
| onnx::GraphProto onnx_root_graph_; | |||
| std::vector<FuncGraphPtr> all_subgraphs_; | |||
| std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_; | |||
| std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_; | |||
| std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node | |||
| FuncGraphPtr anf_root_graph_ = nullptr; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_nonzero_parser.h" | |||
| #include <memory> | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx NonZeroParser"; | |||
| auto attr = std::make_unique<schema::NonZeroT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return nullptr; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_NonZero; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxNonZeroParser("NonZero", new OnnxNonZeroParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxNonZeroParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxNonZeroParser() : OnnxNodeParser("NonZero") {} | |||
| ~OnnxNonZeroParser() override = default; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H | |||
| @@ -96,6 +96,19 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| status = ReplaceIdentity(node, manager); | |||
| } else if (type == schema::PrimitiveType_TupleGetItem) { | |||
| status = ReplaceTupleGetItem(node, manager); | |||
| } else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| (void)Run(sub_func_graph); | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(2)); | |||
| if (sub_func_graph == nullptr) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| return false; | |||
| } | |||
| (void)Run(sub_func_graph); | |||
| } | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "remove identity pass is failed."; | |||
| @@ -296,6 +296,45 @@ STATUS OnnxInputAdjustOpPass::AdjustStridedSlice(const FuncGraphPtr &func_graph, | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS OnnxInputAdjustOpPass::AdjustResize(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto node = cnode->input(0); | |||
| MS_ASSERT(value_node != nullptr); | |||
| auto value_node = node->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| MS_LOG(ERROR) << "cnode input0 is not a valuenode."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| MS_ASSERT(value_node->value() != nullptr); | |||
| auto primitive_c = value_node->value()->cast<PrimitiveCPtr>(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "cnode has no primitive_c."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto primitive = primitive_c->primitiveT(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "cnode has no schema::primitive."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (primitive->value.type != schema::PrimitiveType_Resize) { | |||
| MS_LOG(DEBUG) << "cnode is not cast node."; | |||
| return RET_OK; | |||
| } | |||
| auto value = primitive->value.value; | |||
| if (value == nullptr) { | |||
| MS_LOG(ERROR) << "value is nullptr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto attr = reinterpret_cast<schema::ResizeT *>(value); | |||
| if (cnode->inputs().size() > 3 && | |||
| attr->coordinateTransformMode == schema::CoordinateTransformMode_TF_CROP_AND_RESIZE) { | |||
| auto new_resize_inputs = cnode->inputs(); | |||
| new_resize_inputs.erase(new_resize_inputs.begin() + 1); | |||
| cnode->set_inputs(new_resize_inputs); | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS OnnxInputAdjustOpPass::AdjustConvOrDeConv(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (!CheckInputs(cnode)) { | |||
| @@ -40,6 +40,7 @@ class OnnxInputAdjustOpPass : public Pass { | |||
| STATUS AdjustConvOrDeConv(const CNodePtr &cnode); | |||
| STATUS AdjustTile(const CNodePtr &cnode); | |||
| STATUS AdjustCast(const CNodePtr &cnode); | |||
| STATUS AdjustResize(const CNodePtr &cnode); | |||
| STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||