From 88cda696e19f224e9a79f610cfc135e90c74b7f0 Mon Sep 17 00:00:00 2001 From: kai00 Date: Thu, 6 Aug 2020 21:03:47 +0800 Subject: [PATCH] add concat transpose populater --- .../src/common/anf_exporter/anf_exporter.cc | 12 +++++ .../anf_populater/anf_activation_populater.cc | 3 ++ .../anf_populater/anf_concat_populater.cc | 45 ++++++++++++++++ .../anf_populater/anf_concat_populater.h | 32 +++++++++++ .../anf_depthwiseconv2d_populater.cc | 20 +++++++ .../anf_populater/anf_node_populater.h | 3 ++ .../anf_populater/anf_transpose_populater.cc | 54 +++++++++++++++++++ .../anf_populater/anf_transpose_populater.h | 29 ++++++++++ 8 files changed, 198 insertions(+) create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index 031da47514..8441509fc0 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -244,6 +244,18 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); meta_graph->allTensors.emplace_back(std::move(paramTensor)); + } else if (value->isa()) { + auto valueAbstract = valueNode->abstract(); + auto abstractScalar = utils::cast(valueAbstract); + auto typePtr = abstractScalar->GetTypeTrack(); + paramTensor->dataType = typePtr->type_id(); + paramTensor->dims = {1}; + paramTensor->nodeType = schema::NodeType_ValueNode; + auto data = value->cast(); + paramTensor->data.emplace_back(data->value()); + nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); + fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(paramTensor)); } else if (value->isa()) { MS_LOG(INFO) << "Value type is ValueSequence."; break; diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc index ab9e1b0cd6..bf6a66e57d 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc @@ -29,6 +29,8 @@ int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, attr->type = schema::ActivationType_RELU; } else if (p->name() == "Sigmoid") { attr->type = schema::ActivationType_SIGMOID; + } else if (p->name() == "ReLU6") { + attr->type = schema::ActivationType_RELU6; } node->nodeType = schema::NodeType_CNode; @@ -38,5 +40,6 @@ int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, return 0; } AnfNodePopulaterRegistrar anfReLUParser("ReLU", new AnfActivationPopulater()); +AnfNodePopulaterRegistrar anfReLU6Parser("ReLU6", new AnfActivationPopulater()); AnfNodePopulaterRegistrar anfSigmoidParser("Sigmoid", new AnfActivationPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc new file mode 100644 index 0000000000..1b4596205a --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.cc @@ -0,0 +1,45 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_populater/anf_concat_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfConcatPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + + auto prim_axis = GetValue(p->GetAttr("axis")); + attr->axis = prim_axis; + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Concat; + node->primitive->value.value = attr.release(); + + return 0; +} + +AnfNodePopulaterRegistrar anfConcatParser("Concat", new AnfConcatPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h new file mode 100644 index 0000000000..9a9915dcb5 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_concat_populater.h @@ -0,0 +1,32 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_CONCAT_PARSER_H +#define MINDSPORE_ANF_CONCAT_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfConcatPopulater : public AnfNodePopulater { + public: + AnfConcatPopulater() = default; + ~AnfConcatPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_CONCAT_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc index 5aa9ab1b90..8bb4c79771 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.cc @@ -62,6 +62,26 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP attr->padMode = schema::PadMode_NOTSET; } + auto channel_multiplier = GetValue(p->GetAttr("channel_multiplier")); + attr->channelMultiplier = channel_multiplier; + + MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); + auto inputNode = cnodePtr->input(kAnfPopulaterTwo); + MS_ASSERT(inputNode != nullptr); + if (inputNode->isa()) { + auto paramNode = inputNode->cast(); + auto abstractBase = paramNode->abstract(); + MS_ASSERT(abstractBase != nullptr); + if (utils::isa(abstractBase)) { + auto abstractTensor = utils::cast(abstractBase); + MS_ASSERT(abstractTensor != nullptr); + if (utils::isa(abstractTensor->BuildShape())) { + auto dims = utils::cast(abstractTensor->BuildShape())->shape(); + attr->channelIn = dims[kAnfPopulaterOne]; + } + } + } + node->nodeType = schema::NodeType_CNode; node->primitive = std::make_unique(); node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h index e68645090f..3d9accb75e 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h @@ -21,6 +21,9 @@ #include "ir/anf.h" #include "schema/inner/model_generated.h" namespace mindspore::lite { +constexpr int kAnfPopulaterOne = 1; +constexpr int kAnfPopulaterTwo = 2; +constexpr int kAnfPopulaterThree = 3; class AnfNodePopulater { public: AnfNodePopulater() = default; diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc new file mode 100644 index 0000000000..76eafbec64 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_transpose_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfTransposePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + + MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); + auto inputNode = cnodePtr->input(kAnfPopulaterTwo); + if (inputNode->isa()) { + auto valNode = inputNode->cast(); + MS_ASSERT(valNode != nullptr); + auto val = valNode->value(); + MS_ASSERT(val != nullptr); + if (val->isa()) { + auto tuple = val->cast(); + MS_ASSERT(tuple != nullptr); + for (size_t i = 0; i < tuple->size(); i++) { + auto elem = tuple->value()[i]->cast(); + MS_ASSERT(elem != nullptr); + attr->perm.emplace_back(static_cast(elem->value())); + } + } + } + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Transpose; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfTransposeParser("Transpose", new AnfTransposePopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h new file mode 100644 index 0000000000..eecdbb7593 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_transpose_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H +#define MINDSPORE_ANF_TRANSPOSE_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfTransposePopulater : public AnfNodePopulater { + public: + AnfTransposePopulater() = default; + ~AnfTransposePopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_TRANSPOSE_PARSER_H