Browse Source

!10373 [MSLITE] add onnx loop support

From: @zhengjun10
Reviewed-by: @hangangqiang
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
5602994d48
20 changed files with 1312 additions and 124 deletions
  1. +1
    -0
      mindspore/lite/schema/model.fbs
  2. +2
    -1
      mindspore/lite/schema/model_v0.fbs
  3. +3
    -0
      mindspore/lite/schema/ops.fbs
  4. +3
    -0
      mindspore/lite/schema/ops_v0.fbs
  5. +124
    -0
      mindspore/lite/src/ops/nonzero.cc
  6. +45
    -0
      mindspore/lite/src/ops/nonzero.h
  7. +105
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/nonzero_fp32.cc
  8. +41
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/nonzero_fp32.h
  9. +2
    -1
      mindspore/lite/tools/converter/anf_transform.cc
  10. +42
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_if_parser.cc
  11. +34
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_if_parser.h
  12. +42
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_loop_parser.cc
  13. +34
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_loop_parser.h
  14. +660
    -108
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
  15. +45
    -14
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
  16. +42
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc
  17. +34
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_nonzero_parser.h
  18. +13
    -0
      mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc
  19. +39
    -0
      mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
  20. +1
    -0
      mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h

+ 1
- 0
mindspore/lite/schema/model.fbs View File

@@ -264,6 +264,7 @@ union PrimitiveType {
If,
GeLU,
Gru,
NonZero,
}

enum QuantType: int {


+ 2
- 1
mindspore/lite/schema/model_v0.fbs View File

@@ -236,7 +236,8 @@ union PrimitiveType {
LpNormalization,
DropoutGrad,
MaximumGrad,
MinimumGrad
MinimumGrad,
NonZero,
}

enum QuantType: int {


+ 3
- 0
mindspore/lite/schema/ops.fbs View File

@@ -1241,3 +1241,6 @@ table Merge {
table GeLU {
approximate : bool = false;
}

table NonZero {
}

+ 3
- 0
mindspore/lite/schema/ops_v0.fbs View File

@@ -1143,3 +1143,6 @@ table LpNormalization {
axis : int;
p : int;
}

table NonZero {
}

+ 124
- 0
mindspore/lite/src/ops/nonzero.cc View File

@@ -0,0 +1,124 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/nonzero.h"
#include <algorithm>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int NonZero::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_NonZero;
}
if (this->primitive_->value.type != schema::PrimitiveType_NonZero) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::NonZeroT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
PopulaterQuantParam(prim, inputs);
return RET_OK;
}
#else
int NonZero::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_NonZero();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_NonZero return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateNonZero(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_NonZero, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *NonZeroCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<NonZero>(primitive); }
Registry NonZeroRegistry(schema::PrimitiveType_NonZero, NonZeroCreator);
#endif
template <typename T>
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape) {
int input_count = inputs[0]->ElementsNum();
int input_dim_size = inputs[0]->shape().empty() ? 1 : inputs[0]->shape().size();
(*out_shape)[0] = input_dim_size;
int nonzero_size = 0;
for (int i = 0; i < input_count; i++) {
if (static_cast<int>(data[i]) != 0) {
nonzero_size++;
}
}
if (nonzero_size == 0) {
*out_shape = {};
} else {
(*out_shape)[1] = nonzero_size / input_dim_size;
}
}
int NonZero::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
MS_ASSERT(inputs_.size() == 1);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->set_format(input->format());
if (!infer_flag()) {
return RET_INFER_INVALID;
}

std::vector<int> out_shape;
if (inputs_.size() == kSingleNum) {
auto input_tensor = inputs_.at(0);
if (input_tensor->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
switch (input_tensor->data_type()) {
case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(input_tensor->MutableData());
CalShape<float>(data, inputs_, &out_shape);
} break;
default: {
MS_LOG(ERROR) << "NonZero weight tensor has unsupported dataType: " << input_tensor->data_type();
return RET_INFER_ERR;
}
}
} else {
MS_LOG(ERROR) << "inputs tensor size invalid.";
return RET_INFER_ERR;
}
output->set_shape(out_shape);
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 45
- 0
mindspore/lite/src/ops/nonzero.h View File

@@ -0,0 +1,45 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_OPS_NONZERO_H_
#define MINDSPORE_LITE_SRC_OPS_NONZERO_H_

#include <cmath>
#include <memory>
#include <set>
#include <vector>

#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class NonZero : public PrimitiveC {
public:
NonZero() = default;
~NonZero() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(NonZero, PrimitiveC);
explicit NonZero(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_OPS_NONZERO_H_

+ 105
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/nonzero_fp32.cc View File

@@ -0,0 +1,105 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/nonzero_fp32.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "src/tensor.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_NonZero;

namespace mindspore::kernel {
int NonZeroCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int NonZeroCPUKernel::ReSize() { return RET_OK; }
int NonZeroCPUKernel::Run() {
auto in_tensor = in_tensors_.front();
auto out_tensor = out_tensors_.front();
auto input_data = reinterpret_cast<float *>(in_tensor->MutableData());
auto output_data = reinterpret_cast<int *>(out_tensor->MutableData());
auto input_dim_size = in_tensor->shape().size();
if (out_tensor->shape().size() != 2) {
MS_LOG(ERROR) << "out tensor shape size must be equal to 2!";
return RET_ERROR;
}
auto non_zero_nums = out_tensor->shape()[1];
int non_zero_count = 0;
std::vector coordiate_values(in_tensor->shape().size(), 0);
for (int i = 0; i < in_tensor->ElementsNum(); i += 1) {
if (input_data[i] != 0) {
for (size_t j = 0; j < input_dim_size; j++) {
output_data[non_zero_count + j * non_zero_nums] = coordiate_values[j];
}
non_zero_count++;
}
for (int idx = input_dim_size - 1; idx >= 0; --idx) {
if (coordiate_values[idx] != in_tensor->shape()[idx] - 1) {
coordiate_values[idx] = coordiate_values[idx] + 1;
break;
}
coordiate_values[idx] = 0;
}
}
return RET_OK;
}

kernel::LiteKernel *CpuNonZeroFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
if (ctx == nullptr) {
MS_LOG(ERROR) << "Input context is nullptr!";
free(opParameter);
return nullptr;
}
if (ctx->thread_num_ == 0) {
MS_LOG(ERROR) << "context thread num is 0!";
free(opParameter);
return nullptr;
}
auto *kernel = new (std::nothrow) NonZeroCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new NonZeroCPUKernel fail!";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NonZero, CpuNonZeroFp32KernelCreator)
} // namespace mindspore::kernel

+ 41
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/nonzero_fp32.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_

#include <vector>
#include "src/lite_kernel.h"

namespace mindspore::kernel {
class NonZeroCPUKernel : public LiteKernel {
public:
NonZeroCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}

~NonZeroCPUKernel() = default;

int Init() override;
int ReSize() override;
int Run() override;

protected:
int thread_count_ = 1;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_

+ 2
- 1
mindspore/lite/tools/converter/anf_transform.cc View File

@@ -100,7 +100,8 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
}
}

if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) {
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
config->fmk == lite::converter::FmkType_ONNX) {
graph_pm->AddPass(std::make_shared<opt::WhilePass>());
graph_pm->AddPass(std::make_shared<opt::IfPass>());
}


+ 42
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_if_parser.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/parser/onnx/onnx_if_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"

namespace mindspore {
namespace lite {
lite::PrimitiveC *OnnxIfParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx IfParser";
auto attr = std::make_unique<schema::IfT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_If;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxIfParser("If", new OnnxIfParser());
} // namespace lite
} // namespace mindspore

+ 34
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_if_parser.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H

#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {
class OnnxIfParser : public OnnxNodeParser {
public:
OnnxIfParser() : OnnxNodeParser("If") {}
~OnnxIfParser() override = default;

lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H

+ 42
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_loop_parser.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/parser/onnx/onnx_loop_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"

namespace mindspore {
namespace lite {
lite::PrimitiveC *OnnxLoopParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx LoopParser";
auto attr = std::make_unique<schema::WhileT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_While;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxLoopParser("Loop", new OnnxLoopParser());
} // namespace lite
} // namespace mindspore

+ 34
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_loop_parser.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H

#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {
class OnnxLoopParser : public OnnxNodeParser {
public:
OnnxLoopParser() : OnnxNodeParser("Loop") {}
~OnnxLoopParser() override = default;

lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H

+ 660
- 108
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
File diff suppressed because it is too large
View File


+ 45
- 14
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h View File

@@ -54,29 +54,60 @@ class OnnxModelParser : public ModelParser {

private:
STATUS InitOriginModel(const std::string &model_file);
STATUS ConvertNodes();
STATUS ConvertConstTensors();
STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs();
STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs);
STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
const std::string &root_node_name);
STATUS ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
std::vector<AnfNodePtr> *graph_inputs, const std::string &root_node_name);
STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map);
STATUS ConvertGraphInputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *nodes_map);
STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map);
STATUS BuildReturnNode(const FuncGraphPtr &func_graph_ptr, const std::vector<AnfNodePtr> &return_inputs);
STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor);
STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type);
STATUS BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode);
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, const std::string &name);
STATUS BuildCNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
lite::PrimitiveC *primitive_c, std::string loop_name);
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode);
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
lite::PrimitiveC *primitive_c);
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, lite::PrimitiveC *primitive_c);
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, lite::PrimitiveC *primitive_c,
const std::string &name);
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node);

STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
const std::string &root_node_name);
STATUS ConvertIfOnnxNode(const onnx::NodeProto &onnx_node, std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
const std::string &root_node_name);
STATUS AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
const std::string &loop_node_name, std::vector<AnfNodePtr> *body_graph_inputs,
int act_output_num);
STATUS BuildCondGraph(const FuncGraphPtr &cond_graph, const AnfNodePtr &root_while_node, int inputs_num,
const std::string &cond_graph_name);
STATUS ConvertIfSubgraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
const std::string &subgrah_name, const std::string &if_node_name,
const std::string &root_node_name);
onnx::ModelProto onnx_model_;
onnx::GraphProto onnx_graph_;
std::unordered_map<std::string, AnfNodePtr> nodes_;
FuncGraphPtr func_graph_ptr_ = nullptr;
onnx::GraphProto onnx_root_graph_;
std::vector<FuncGraphPtr> all_subgraphs_;
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
FuncGraphPtr anf_root_graph_ = nullptr;
};
} // namespace lite
} // namespace mindspore


+ 42
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/parser/onnx/onnx_nonzero_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"

namespace mindspore {
namespace lite {
lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx NonZeroParser";
auto attr = std::make_unique<schema::NonZeroT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_NonZero;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxNonZeroParser("NonZero", new OnnxNonZeroParser());
} // namespace lite
} // namespace mindspore

+ 34
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_nonzero_parser.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H

#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {
class OnnxNonZeroParser : public OnnxNodeParser {
public:
OnnxNonZeroParser() : OnnxNodeParser("NonZero") {}
~OnnxNonZeroParser() override = default;

lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H

+ 13
- 0
mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc View File

@@ -96,6 +96,19 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
status = ReplaceIdentity(node, manager);
} else if (type == schema::PrimitiveType_TupleGetItem) {
status = ReplaceTupleGetItem(node, manager);
} else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
(void)Run(sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(2));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
(void)Run(sub_func_graph);
}
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "remove identity pass is failed.";


+ 39
- 0
mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc View File

@@ -296,6 +296,45 @@ STATUS OnnxInputAdjustOpPass::AdjustStridedSlice(const FuncGraphPtr &func_graph,
return lite::RET_OK;
}

STATUS OnnxInputAdjustOpPass::AdjustResize(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto node = cnode->input(0);
MS_ASSERT(value_node != nullptr);
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "cnode input0 is not a valuenode.";
return lite::RET_ERROR;
}
MS_ASSERT(value_node->value() != nullptr);
auto primitive_c = value_node->value()->cast<PrimitiveCPtr>();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "cnode has no primitive_c.";
return lite::RET_ERROR;
}
auto primitive = primitive_c->primitiveT();
if (primitive == nullptr) {
MS_LOG(ERROR) << "cnode has no schema::primitive.";
return lite::RET_ERROR;
}
if (primitive->value.type != schema::PrimitiveType_Resize) {
MS_LOG(DEBUG) << "cnode is not cast node.";
return RET_OK;
}
auto value = primitive->value.value;
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr.";
return lite::RET_ERROR;
}
auto attr = reinterpret_cast<schema::ResizeT *>(value);
if (cnode->inputs().size() > 3 &&
attr->coordinateTransformMode == schema::CoordinateTransformMode_TF_CROP_AND_RESIZE) {
auto new_resize_inputs = cnode->inputs();
new_resize_inputs.erase(new_resize_inputs.begin() + 1);
cnode->set_inputs(new_resize_inputs);
}
return lite::RET_OK;
}

STATUS OnnxInputAdjustOpPass::AdjustConvOrDeConv(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
if (!CheckInputs(cnode)) {


+ 1
- 0
mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h View File

@@ -40,6 +40,7 @@ class OnnxInputAdjustOpPass : public Pass {
STATUS AdjustConvOrDeConv(const CNodePtr &cnode);
STATUS AdjustTile(const CNodePtr &cnode);
STATUS AdjustCast(const CNodePtr &cnode);
STATUS AdjustResize(const CNodePtr &cnode);
STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
bool Run(const FuncGraphPtr &func_graph) override;


Loading…
Cancel
Save