Browse Source

!11797 fix bug of fill inputs order

From: @cjh9368
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
ae39553df1
7 changed files with 172 additions and 36 deletions
  1. +1
    -0
      mindspore/lite/schema/model.fbs
  2. +3
    -0
      mindspore/lite/schema/ops.fbs
  3. +32
    -0
      mindspore/lite/src/ops/erf.h
  4. +3
    -0
      mindspore/lite/src/ops/primitive_c.cc
  5. +42
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.cc
  6. +33
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.h
  7. +58
    -36
      mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc

+ 1
- 0
mindspore/lite/schema/model.fbs View File

@@ -272,6 +272,7 @@ union PrimitiveType {
Size,
RandomStandardNormal,
CropAndResize,
Erf,
}

enum QuantType: int {


+ 3
- 0
mindspore/lite/schema/ops.fbs View File

@@ -1260,3 +1260,6 @@ table CropAndResize {
method : ResizeMethod;
extrapolation_value : float;
}

table Erf {
}

+ 32
- 0
mindspore/lite/src/ops/erf.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/primitive_c.h"

#ifndef LITE_MINDSPORE_LITE_C_OPS_ERF_H_
#define LITE_MINDSPORE_LITE_C_OPS_ERF_H_

namespace mindspore {
namespace lite {
class Erf : public PrimitiveC {
public:
MS_DECLARE_PARENT(Erf, PrimitiveC);
Erf() = default;
~Erf() = default;
explicit Erf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ERF_H_

+ 3
- 0
mindspore/lite/src/ops/primitive_c.cc View File

@@ -169,6 +169,7 @@
#include "src/ops/invert_permutation.h"
#include "src/ops/crop_and_resize.h"
#include "src/ops/nonzero.h"
#include "src/ops/erf.h"

#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@@ -1028,6 +1029,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) CropAndResize(primitive);
case schema::PrimitiveType_NonZero:
return new (std::nothrow) NonZero(primitive);
case schema::PrimitiveType_Erf:
return new (std::nothrow) Erf(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);


+ 42
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_erf_parser.h"
#include <memory>

namespace mindspore {
namespace lite {
lite::PrimitiveC *OnnxErfParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ErfParser";
auto attr = std::make_unique<schema::ErfT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Erf;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}

OnnxNodeRegistrar g_onnx_erf_parser("Erf", new OnnxErfParser());
} // namespace lite
} // namespace mindspore

+ 33
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ERF_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ERF_PARSER_H

#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {
class OnnxErfParser : public OnnxNodeParser {
public:
OnnxErfParser() : OnnxNodeParser("Erf") {}
~OnnxErfParser() override = default;

lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ERF_PARSER_H

+ 58
- 36
mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc View File

@@ -19,12 +19,38 @@
#include "tools/optimizer/common/gllo_utils.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "src/common/utils.h"

using mindspore::lite::PrimitiveC;
namespace mindspore::opt {
namespace {
constexpr size_t split_inputs_size = 3;
const std::vector<schema::PrimitiveType> single_input_ops = {
schema::PrimitiveType_Reduce, schema::PrimitiveType_ArgMin, schema::PrimitiveType_ArgMax,
schema::PrimitiveType_SpaceToBatch, schema::PrimitiveType_BatchToSpace, schema::PrimitiveType_SpaceToBatchND,
schema::PrimitiveType_BatchToSpaceND, schema::PrimitiveType_SpaceToDepth};
} // namespace

STATUS ReorderCnodeInputs(CNode *cnode, const std::vector<size_t> &perm) {
// add primitive first
std::vector<AnfNodePtr> new_inputs = {cnode->input(0)};
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
auto old_quant_params = primitive_c->input_quant_params();
std::vector<std::vector<schema::QuantParamT>> new_quant_params;
// add inputs as perm order
for (size_t idx : perm) {
if (idx > cnode->inputs().size() - 1) {
MS_LOG(ERROR) << "Idx " << idx << " is larger than inputs size: " << cnode->inputs().size() - 1;
return RET_ERROR;
}
new_inputs.emplace_back(cnode->input(idx));
new_quant_params.emplace_back(old_quant_params.at(idx - 1));
}
cnode->set_inputs(new_inputs);
primitive_c->set_input_quant_params(new_quant_params);
return RET_OK;
}

bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) {
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
@@ -33,50 +59,46 @@ bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) {
}
auto cnode = node->cast<CNodePtr>();
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (opt::GetCNodeType(node) == schema::PrimitiveType_DeConv2D) {
cnode->set_input(1, cnode->input(3));
auto inputs = cnode->inputs();
inputs.pop_back();
cnode->set_inputs(inputs);

auto input_quant_params = primitive_c->input_quant_params();
input_quant_params[0] = input_quant_params.at(2);
input_quant_params.pop_back();
primitive_c->set_input_quant_params(input_quant_params);
if (opt::GetCNodeType(node) == schema::PrimitiveType_Fill) {
// dims, value => value, dims
if (RET_OK != ReorderCnodeInputs(cnode.get(), {2, 1})) {
MS_LOG(ERROR) << "Reorder fill inputs failed";
return false;
}
continue;
}

if (opt::GetCNodeType(node) == schema::PrimitiveType_Split && cnode->inputs().size() == split_inputs_size) {
cnode->set_input(1, cnode->input(2));
auto inputs = cnode->inputs();
inputs.pop_back();
cnode->set_inputs(inputs);
if (opt::GetCNodeType(node) == schema::PrimitiveType_DeConv2D) {
// output_shape, weights, input => input, weight
if (RET_OK != ReorderCnodeInputs(cnode.get(), {3, 2})) {
MS_LOG(ERROR) << "Reorder deconv inputs failed";
return false;
}
continue;
}

auto input_quant_params = primitive_c->input_quant_params();
input_quant_params[0] = input_quant_params.at(1);
input_quant_params.pop_back();
primitive_c->set_input_quant_params(input_quant_params);
if (opt::GetCNodeType(node) == schema::PrimitiveType_Split && cnode->inputs().size() == split_inputs_size) {
// axis, input, ??? => input, axis
if (RET_OK != ReorderCnodeInputs(cnode.get(), {2, 1})) {
MS_LOG(ERROR) << "Reorder split inputs failed";
return false;
}
continue;
}

if (opt::GetCNodeType(node) == schema::PrimitiveType_Reduce ||
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMin ||
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMax ||
opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatch ||
opt::GetCNodeType(node) == schema::PrimitiveType_BatchToSpace ||
opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatchND ||
opt::GetCNodeType(node) == schema::PrimitiveType_BatchToSpaceND ||
opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToDepth ||
(opt::GetCNodeType(node) == schema::PrimitiveType_Pad && primitive_c->primitiveT()->value.AsPad() != nullptr &&
primitive_c->primitiveT()->value.AsPad()->paddingMode == schema::PaddingMode_CONSTANT) ||
(opt::GetCNodeType(node) == schema::PrimitiveType_Resize &&
primitive_c->primitiveT()->value.AsResize() != nullptr &&
primitive_c->primitiveT()->value.AsResize()->newHeight != 0 &&
primitive_c->primitiveT()->value.AsResize()->newWidth != 0)) {
std::vector<AnfNodePtr> new_inputs;
new_inputs.emplace_back(cnode->input(0));
new_inputs.emplace_back(cnode->input(1));
cnode->set_inputs(new_inputs);
bool is_single_input_pad = opt::GetCNodeType(node) == schema::PrimitiveType_Pad &&
primitive_c->primitiveT()->value.AsPad() != nullptr &&
primitive_c->primitiveT()->value.AsPad()->paddingMode == schema::PaddingMode_CONSTANT;
bool is_single_input_resize = opt::GetCNodeType(node) == schema::PrimitiveType_Resize &&
primitive_c->primitiveT()->value.AsResize() != nullptr &&
primitive_c->primitiveT()->value.AsResize()->newHeight != 0 &&
primitive_c->primitiveT()->value.AsResize()->newWidth != 0;
if (lite::IsContain(single_input_ops, opt::GetCNodeType(node)) || is_single_input_pad || is_single_input_resize) {
if (RET_OK != ReorderCnodeInputs(cnode.get(), {1})) {
MS_LOG(ERROR) << "Reorder single input failed";
return false;
}
continue;
}
}


Loading…
Cancel
Save