diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 0b24e00ca9..3b5e082d46 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -272,6 +272,7 @@ union PrimitiveType { Size, RandomStandardNormal, CropAndResize, + Erf, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 5ef406aec6..60dd7d6603 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1260,3 +1260,6 @@ table CropAndResize { method : ResizeMethod; extrapolation_value : float; } + +table Erf { +} \ No newline at end of file diff --git a/mindspore/lite/src/ops/erf.h b/mindspore/lite/src/ops/erf.h new file mode 100644 index 0000000000..a8c9c56038 --- /dev/null +++ b/mindspore/lite/src/ops/erf.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/ops/primitive_c.h" + +#ifndef LITE_MINDSPORE_LITE_C_OPS_ERF_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ERF_H_ + +namespace mindspore { +namespace lite { +class Erf : public PrimitiveC { + public: + MS_DECLARE_PARENT(Erf, PrimitiveC); + Erf() = default; + ~Erf() = default; + explicit Erf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_MINDSPORE_LITE_C_OPS_ERF_H_ diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 60228a74ac..bc3c491f7d 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -169,6 +169,7 @@ #include "src/ops/invert_permutation.h" #include "src/ops/crop_and_resize.h" #include "src/ops/nonzero.h" +#include "src/ops/erf.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -1028,6 +1029,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) CropAndResize(primitive); case schema::PrimitiveType_NonZero: return new (std::nothrow) NonZero(primitive); + case schema::PrimitiveType_Erf: + return new (std::nothrow) Erf(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.cc new file mode 100644 index 0000000000..d8de36ab97 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/onnx/onnx_erf_parser.h" +#include + +namespace mindspore { +namespace lite { +lite::PrimitiveC *OnnxErfParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { + MS_LOG(DEBUG) << "onnx ErfParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Erf; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); +} + +OnnxNodeRegistrar g_onnx_erf_parser("Erf", new OnnxErfParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.h new file mode 100644 index 0000000000..532337e4ec --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_erf_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ERF_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ERF_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxErfParser : public OnnxNodeParser { + public: + OnnxErfParser() : OnnxNodeParser("Erf") {} + ~OnnxErfParser() override = default; + + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ERF_PARSER_H diff --git a/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc b/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc index 30acda2f04..abb8daaef9 100644 --- a/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc @@ -19,12 +19,38 @@ #include "tools/optimizer/common/gllo_utils.h" #include "schema/inner/model_generated.h" #include "tools/converter/quantizer/quant_cast.h" +#include "src/common/utils.h" using mindspore::lite::PrimitiveC; namespace mindspore::opt { namespace { constexpr size_t split_inputs_size = 3; +const std::vector single_input_ops = { + schema::PrimitiveType_Reduce, schema::PrimitiveType_ArgMin, schema::PrimitiveType_ArgMax, + schema::PrimitiveType_SpaceToBatch, schema::PrimitiveType_BatchToSpace, schema::PrimitiveType_SpaceToBatchND, + schema::PrimitiveType_BatchToSpaceND, schema::PrimitiveType_SpaceToDepth}; } // namespace + +STATUS ReorderCnodeInputs(CNode *cnode, const std::vector &perm) { + // add primitive first + std::vector new_inputs = {cnode->input(0)}; + auto primitive_c = GetValueNode>(cnode->input(0)); + auto old_quant_params = primitive_c->input_quant_params(); + std::vector> new_quant_params; + // add inputs as perm order + for (size_t idx : perm) { + if (idx > cnode->inputs().size() - 1) { + MS_LOG(ERROR) << "Idx " << idx << " is larger than inputs size: " << cnode->inputs().size() - 1; + return RET_ERROR; + } + new_inputs.emplace_back(cnode->input(idx)); + new_quant_params.emplace_back(old_quant_params.at(idx - 1)); + } + cnode->set_inputs(new_inputs); + primitive_c->set_input_quant_params(new_quant_params); + return RET_OK; +} + bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) { auto node_list = TopoSort(graph->get_return()); for (auto &node : node_list) { @@ -33,50 +59,46 @@ bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) { } auto cnode = node->cast(); auto primitive_c = GetValueNode>(cnode->input(0)); - if (opt::GetCNodeType(node) == schema::PrimitiveType_DeConv2D) { - cnode->set_input(1, cnode->input(3)); - auto inputs = cnode->inputs(); - inputs.pop_back(); - cnode->set_inputs(inputs); - auto input_quant_params = primitive_c->input_quant_params(); - input_quant_params[0] = input_quant_params.at(2); - input_quant_params.pop_back(); - primitive_c->set_input_quant_params(input_quant_params); + if (opt::GetCNodeType(node) == schema::PrimitiveType_Fill) { + // dims, value => value, dims + if (RET_OK != ReorderCnodeInputs(cnode.get(), {2, 1})) { + MS_LOG(ERROR) << "Reorder fill inputs failed"; + return false; + } continue; } - if (opt::GetCNodeType(node) == schema::PrimitiveType_Split && cnode->inputs().size() == split_inputs_size) { - cnode->set_input(1, cnode->input(2)); - auto inputs = cnode->inputs(); - inputs.pop_back(); - cnode->set_inputs(inputs); + if (opt::GetCNodeType(node) == schema::PrimitiveType_DeConv2D) { + // output_shape, weights, input => input, weight + if (RET_OK != ReorderCnodeInputs(cnode.get(), {3, 2})) { + MS_LOG(ERROR) << "Reorder deconv inputs failed"; + return false; + } + continue; + } - auto input_quant_params = primitive_c->input_quant_params(); - input_quant_params[0] = input_quant_params.at(1); - input_quant_params.pop_back(); - primitive_c->set_input_quant_params(input_quant_params); + if (opt::GetCNodeType(node) == schema::PrimitiveType_Split && cnode->inputs().size() == split_inputs_size) { + // axis, input, ??? => input, axis + if (RET_OK != ReorderCnodeInputs(cnode.get(), {2, 1})) { + MS_LOG(ERROR) << "Reorder split inputs failed"; + return false; + } continue; } - if (opt::GetCNodeType(node) == schema::PrimitiveType_Reduce || - opt::GetCNodeType(node) == schema::PrimitiveType_ArgMin || - opt::GetCNodeType(node) == schema::PrimitiveType_ArgMax || - opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatch || - opt::GetCNodeType(node) == schema::PrimitiveType_BatchToSpace || - opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatchND || - opt::GetCNodeType(node) == schema::PrimitiveType_BatchToSpaceND || - opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToDepth || - (opt::GetCNodeType(node) == schema::PrimitiveType_Pad && primitive_c->primitiveT()->value.AsPad() != nullptr && - primitive_c->primitiveT()->value.AsPad()->paddingMode == schema::PaddingMode_CONSTANT) || - (opt::GetCNodeType(node) == schema::PrimitiveType_Resize && - primitive_c->primitiveT()->value.AsResize() != nullptr && - primitive_c->primitiveT()->value.AsResize()->newHeight != 0 && - primitive_c->primitiveT()->value.AsResize()->newWidth != 0)) { - std::vector new_inputs; - new_inputs.emplace_back(cnode->input(0)); - new_inputs.emplace_back(cnode->input(1)); - cnode->set_inputs(new_inputs); + bool is_single_input_pad = opt::GetCNodeType(node) == schema::PrimitiveType_Pad && + primitive_c->primitiveT()->value.AsPad() != nullptr && + primitive_c->primitiveT()->value.AsPad()->paddingMode == schema::PaddingMode_CONSTANT; + bool is_single_input_resize = opt::GetCNodeType(node) == schema::PrimitiveType_Resize && + primitive_c->primitiveT()->value.AsResize() != nullptr && + primitive_c->primitiveT()->value.AsResize()->newHeight != 0 && + primitive_c->primitiveT()->value.AsResize()->newWidth != 0; + if (lite::IsContain(single_input_ops, opt::GetCNodeType(node)) || is_single_input_pad || is_single_input_resize) { + if (RET_OK != ReorderCnodeInputs(cnode.get(), {1})) { + MS_LOG(ERROR) << "Reorder single input failed"; + return false; + } continue; } }