From: @xu_anyue Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -261,6 +261,7 @@ union PrimitiveType { | |||
| Reciprocal, | |||
| Merge, | |||
| Mod, | |||
| GeLU, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -34,7 +34,36 @@ void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgM | |||
| void ArgMax::SetTopK(int top_k) { this->primitive_->value.AsArgMax()->topK = top_k; } | |||
| void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->keepDims = keep_dims; } | |||
| void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; } | |||
| int ArgMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive error"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_ArgMax; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_ArgMax) { | |||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto argmax_attr = new (std::nothrow) schema::ArgMaxT(); | |||
| if (argmax_attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive value.value error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (prim.GetAttr("axis") != nullptr) { | |||
| argmax_attr->axis = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("axis"))); | |||
| } | |||
| if (prim.GetAttr("keep_dims") != nullptr) { | |||
| argmax_attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims"))); | |||
| } | |||
| argmax_attr->outMaxValue = false; | |||
| this->primitive_->value.value = argmax_attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int ArgMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -37,6 +37,7 @@ class ArgMax : public PrimitiveC { | |||
| void SetTopK(int top_k); | |||
| void SetKeepDims(bool keep_dims); | |||
| void SetAxisType(int axis_type); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -61,6 +61,7 @@ int ArgMin::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| if (prim.GetAttr("keep_dims") != nullptr) { | |||
| attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims"))); | |||
| } | |||
| attr->outMaxValue = false; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/gelu.h" | |||
| #include <memory> | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/tensor.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int GeLU::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_GeLU; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_GeLU) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::GeLUT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_GELU_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_GELU_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class GeLU : public PrimitiveC { | |||
| public: | |||
| GeLU() = default; | |||
| ~GeLU() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(GeLU, PrimitiveC); | |||
| explicit GeLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #endif | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_GELU_H_ | |||
| @@ -35,7 +35,42 @@ void LayerNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsLayerNorm( | |||
| void LayerNorm::SetElementwiseAffine(bool elementwiseAffine) { | |||
| this->primitive_->value.AsLayerNorm()->elementwiseAffine = elementwiseAffine; | |||
| } | |||
| int LayerNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive error"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_LayerNorm; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_LayerNorm) { | |||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto layer_norm_attr = new (std::nothrow) schema::LayerNormT(); | |||
| if (layer_norm_attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive value.value error"; | |||
| return RET_ERROR; | |||
| } | |||
| auto value_attr = prim.GetAttr("epsilon"); | |||
| if (value_attr != nullptr) { | |||
| layer_norm_attr->epsilon = GetValue<float>(value_attr); | |||
| } else { | |||
| layer_norm_attr->epsilon = 1e-7; | |||
| } | |||
| value_attr = prim.GetAttr("normalized_shape"); | |||
| if (value_attr != nullptr) { | |||
| layer_norm_attr->normalizedShape = CastToInt(value_attr); | |||
| } | |||
| if (inputs.size() == 3) { | |||
| layer_norm_attr->elementwiseAffine = true; | |||
| } | |||
| this->primitive_->value.value = layer_norm_attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int LayerNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -100,13 +135,12 @@ int LayerNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite: | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (normlized_shape_.empty()) { | |||
| // instance norm -> layernorm | |||
| // instance norm -> layernorm only for nchw | |||
| if (input->format() == schema::Format_NCHW) { | |||
| normlized_shape_.insert(normlized_shape_.begin(), input_shape.begin() + 2, input_shape.end()); | |||
| elementwise_mode_ = 1; | |||
| } else { | |||
| MS_LOG(INFO) << "normalized_shape attr invalid"; | |||
| return RET_PARAM_INVALID; | |||
| normlized_shape_.insert(normlized_shape_.begin(), input_shape.begin() + 1, input_shape.end()); | |||
| } | |||
| } | |||
| size_t first_index = input_shape.size() - normlized_shape_.size(); | |||
| @@ -35,6 +35,7 @@ class LayerNorm : public PrimitiveC { | |||
| void SetNormalizedShape(const std::vector<int> &normalizedShape); | |||
| void SetEpsilon(float epsilon); | |||
| void SetElementwiseAffine(bool elementwiseAffine); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| @@ -160,6 +160,7 @@ | |||
| #include "src/ops/merge.h" | |||
| #include "src/ops/switch.h" | |||
| #include "src/ops/partial.h" | |||
| #include "src/ops/gelu.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| #include "src/ops/neg_grad.h" | |||
| @@ -330,9 +331,28 @@ void PrimitiveC::PopulaterOutputQuantParam(const Primitive &prim, bool narrowRan | |||
| void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| auto narrow_range = prim.GetAttr("narrow_range"); | |||
| bool narrowRangeQuantParam = narrow_range != nullptr && GetValue<bool>(narrow_range); | |||
| bool narrowRangeQuantParam = false; | |||
| if (narrow_range != nullptr) { | |||
| if (utils::isa<tensor::TensorPtr>(narrow_range)) { | |||
| auto narrow_range_tensor = narrow_range->cast<tensor::TensorPtr>(); | |||
| narrowRangeQuantParam = *reinterpret_cast<bool *>(narrow_range_tensor->data_c()); | |||
| } else if (utils::isa<ImmTraits<bool>::type>(narrow_range)) { | |||
| narrowRangeQuantParam = GetValue<bool>(narrow_range); | |||
| } else { | |||
| MS_LOG(ERROR) << "valueptr is invalid."; | |||
| return; | |||
| } | |||
| } | |||
| auto num_bits = prim.GetAttr("num_bits"); | |||
| int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int64_t>(num_bits) : 8; | |||
| int32_t numbitsRangeQuantParam = 8; | |||
| if (num_bits != nullptr) { | |||
| if (utils::isa<tensor::TensorPtr>(num_bits)) { | |||
| auto num_bits_tensor = num_bits->cast<tensor::TensorPtr>(); | |||
| numbitsRangeQuantParam = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c()); | |||
| } else if (utils::isa<ImmTraits<int64_t>::type>(num_bits)) { | |||
| numbitsRangeQuantParam = GetValue<int64_t>(num_bits); | |||
| } | |||
| } | |||
| PopulaterInputQuantParam(prim, inputs, narrowRangeQuantParam, numbitsRangeQuantParam); | |||
| PopulaterOutputQuantParam(prim, narrowRangeQuantParam, numbitsRangeQuantParam); | |||
| } | |||
| @@ -511,7 +531,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<FusedBatchNorm>(prim, inputs, quantType); | |||
| } else if (op_type == "make_tuple") { | |||
| return NewPrimitiveC<MakeTuple>(prim, inputs, quantType); | |||
| } else if (op_type == "MatMul") { | |||
| } else if (op_type == "MatMul" || op_type == "BatchMatMul") { | |||
| return NewPrimitiveC<MatMul>(prim, inputs, quantType); | |||
| } else if (op_type == "Mul") { | |||
| return NewPrimitiveC<Mul>(prim, inputs, quantType); | |||
| @@ -601,7 +621,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<TopK>(prim, inputs, quantType); | |||
| } else if (op_type == "Mod") { | |||
| return NewPrimitiveC<Mod>(prim, inputs, quantType); | |||
| } else if (op_type == "ArgMinWithValue") { | |||
| } else if (op_type == "ArgMin" || op_type == "ArgMinWithValue") { | |||
| return NewPrimitiveC<ArgMin>(prim, inputs, quantType); | |||
| } else if (op_type == "Range") { | |||
| return NewPrimitiveC<Range>(prim, inputs, quantType); | |||
| @@ -621,6 +641,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Partial>(prim, inputs, quantType); | |||
| } else if (op_type == "Merge") { | |||
| return NewPrimitiveC<Merge>(prim, inputs, quantType); | |||
| } else if (op_type == "LayerNorm") { | |||
| return NewPrimitiveC<LayerNorm>(prim, inputs, quantType); | |||
| } else if (op_type == "ArgMax" || op_type == "ArgMaxWithValue") { | |||
| return NewPrimitiveC<ArgMax>(prim, inputs, quantType); | |||
| } else if (op_type == "Gelu") { | |||
| return NewPrimitiveC<GeLU>(prim, inputs, quantType); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | |||
| @@ -965,6 +991,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) Partial(primitive); | |||
| case schema::PrimitiveType_Assert: | |||
| return new (std::nothrow) AssertOP(primitive); | |||
| case schema::PrimitiveType_GeLU: | |||
| return new (std::nothrow) GeLU(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| return new (std::nothrow) ActivationGrad(primitive); | |||
| @@ -199,6 +199,7 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/mindir_inputs_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/while_pass.cc | |||
| ) | |||
| @@ -737,7 +737,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| meta_graphT->allTensors.emplace_back(msTensor); | |||
| if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) { | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) { | |||
| break; | |||
| } | |||
| #endif | |||
| @@ -60,6 +60,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/mindir_adjust_pass.cc | |||
| ../optimizer/graph/onnx_inputs_adjust_pass.cc | |||
| ../optimizer/graph/while_pass.cc | |||
| ../optimizer/graph/mindir_inputs_adjust_pass.cc | |||
| ) | |||
| add_subdirectory(../anf_importer anf_importer) | |||
| @@ -30,6 +30,7 @@ | |||
| #include "tools/optimizer/fusion/sigmoid_mul_fusion.h" | |||
| #include "tools/optimizer/fusion/conv_conv_fusion.h" | |||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | |||
| #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" | |||
| #include "tools/optimizer/graph/identity_remove_pass.h" | |||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | |||
| @@ -77,6 +78,12 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| auto mindir_inputs_adjust_pass = std::make_shared<opt::MindirInputAdjustOpPass>(); | |||
| if (!mindir_inputs_adjust_pass->Run(old_graph)) { | |||
| MS_LOG(ERROR) << "mindir inputs adjust failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| } | |||
| // onnx pre adjustment | |||
| @@ -0,0 +1,236 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/tensor.h" | |||
| using mindspore::lite::PrimitiveC; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| template <typename T> | |||
| void CopyAttrForArgMinMax(T *left, T *right) { | |||
| MS_ASSERT(left != null && right != nullptr); | |||
| left->axis = right->axis; | |||
| left->outMaxValue = right->outMaxValue; | |||
| left->axisType = right->axisType; | |||
| left->keepDims = right->keepDims; | |||
| left->topK = right->topK; | |||
| } | |||
| } // namespace | |||
| bool MindirInputAdjustOpPass::CheckCNodeIsArgMinMax(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto prim_node = cnode->inputs().at(0); | |||
| MS_ASSERT(prim_node != nullptr); | |||
| auto prim_value_node = prim_node->cast<ValueNodePtr>(); | |||
| if (prim_value_node == nullptr) { | |||
| MS_LOG(DEBUG) << "cnode first input is not valueNode."; | |||
| return false; | |||
| } | |||
| auto value = prim_value_node->value(); | |||
| MS_ASSERT(value != nullptr); | |||
| auto prim_c = value->cast<PrimitiveCPtr>(); | |||
| if (prim_c == nullptr) { | |||
| MS_LOG(DEBUG) << "prim is not primitiveC."; | |||
| return false; | |||
| } | |||
| auto prim = prim_c->primitiveT(); | |||
| MS_ASSERT(prim != nullptr); | |||
| return prim->value.type == schema::PrimitiveType_ArgMax || prim->value.type == schema::PrimitiveType_ArgMin; | |||
| } | |||
| int MindirInputAdjustOpPass::AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value) { | |||
| MS_ASSERT(inputs != nullptr); | |||
| auto prim_node = inputs->at(0); | |||
| MS_ASSERT(prim_node != nullptr); | |||
| auto prim_value_node = prim_node->cast<ValueNodePtr>(); | |||
| if (prim_value_node == nullptr) { | |||
| MS_LOG(ERROR) << "cnode first input is not valueNode."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim_value = prim_value_node->value(); | |||
| if (prim_value == nullptr) { | |||
| MS_LOG(ERROR) << "valueNode value is nullptr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim_c = prim_value->cast<PrimitiveCPtr>(); | |||
| if (prim_c == nullptr) { | |||
| MS_LOG(ERROR) << "value is not primitiveC."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim = prim_c->primitiveT(); | |||
| MS_ASSERT(prim != nullptr && prim->value.value != nullptr); | |||
| auto attr = prim->value.value; | |||
| if (prim->value.type == schema::PrimitiveType_ArgMax) { | |||
| reinterpret_cast<schema::ArgMaxT *>(attr)->outMaxValue = index_or_value; | |||
| } else if (prim->value.type == schema::PrimitiveType_ArgMin) { | |||
| reinterpret_cast<schema::ArgMinT *>(attr)->outMaxValue = index_or_value; | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| int MindirInputAdjustOpPass::CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs) { | |||
| MS_ASSERT(inputs != nullptr); | |||
| auto prim_node = inputs->at(0); | |||
| MS_ASSERT(prim_node != nullptr); | |||
| auto prim_value_node = prim_node->cast<ValueNodePtr>(); | |||
| if (prim_value_node == nullptr) { | |||
| MS_LOG(ERROR) << "cnode first input is not valueNode."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim_value = prim_value_node->value(); | |||
| if (prim_value == nullptr) { | |||
| MS_LOG(ERROR) << "valueNode value is nullptr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim_c = prim_value->cast<PrimitiveCPtr>(); | |||
| if (prim_c == nullptr) { | |||
| MS_LOG(ERROR) << "value is not primitiveC."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto prim = prim_c->primitiveT(); | |||
| MS_ASSERT(prim != nullptr && prim->value.value != nullptr); | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (prim->value.type == schema::PrimitiveType_ArgMax) { | |||
| primitive->value.type = schema::PrimitiveType_ArgMax; | |||
| auto attr = std::make_unique<schema::ArgMaxT>(); | |||
| CopyAttrForArgMinMax<schema::ArgMaxT>(attr.get(), reinterpret_cast<schema::ArgMaxT *>(prim->value.value)); | |||
| primitive->value.value = attr.release(); | |||
| } else { | |||
| primitive->value.type = schema::PrimitiveType_ArgMin; | |||
| auto attr = std::make_unique<schema::ArgMinT>(); | |||
| CopyAttrForArgMinMax<schema::ArgMinT>(attr.get(), reinterpret_cast<schema::ArgMinT *>(prim->value.value)); | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| auto primitive_c = PrimitiveC::Create(primitive.release()); | |||
| auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitive_c)); | |||
| inputs->erase(inputs->begin()); | |||
| inputs->insert(inputs->begin(), value_node); | |||
| return lite::RET_OK; | |||
| } | |||
| int MindirInputAdjustOpPass::BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, | |||
| const CNodePtr &argmin_max) { | |||
| MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr); | |||
| auto inputs = argmin_max->inputs(); | |||
| if (CopyPrimitiveCForArgMinMax(&inputs) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "copy argmin or argmax failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (AdjustArgMinMaxInputs(&inputs, false) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "adjust argmin or argmax attr failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto new_cnode = graph->NewCNode(inputs); | |||
| new_cnode->set_fullname_with_scope(argmin_max->fullname_with_scope() + "_index"); | |||
| auto type_ptr = TypeIdToType(kTypeUnknown); | |||
| std::vector<int64_t> shape_vector; | |||
| new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||
| auto manager = graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| manager->Replace(tuple_get_item, new_cnode); | |||
| return lite::RET_OK; | |||
| } | |||
| int MindirInputAdjustOpPass::AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, | |||
| const CNodePtr &argmin_max) { | |||
| MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr); | |||
| auto inputs = argmin_max->inputs(); | |||
| if (AdjustArgMinMaxInputs(&inputs, true) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "adjust argmin or argmax attr failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto type_ptr = TypeIdToType(kTypeUnknown); | |||
| std::vector<int64_t> shape_vector; | |||
| auto abtract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| argmin_max->set_abstract(abtract_tensor); | |||
| auto manager = graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| manager->Replace(tuple_get_item, argmin_max); | |||
| return lite::RET_OK; | |||
| } | |||
| int MindirInputAdjustOpPass::AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(graph != nullptr && cnode != nullptr); | |||
| auto inputs = cnode->inputs(); | |||
| if (inputs.size() != 3) { | |||
| MS_LOG(ERROR) << "tupleGetItem inputs size is invalid: " << inputs.size(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto argmin_max = inputs.at(1); | |||
| MS_ASSERT(argmin_max != nullptr); | |||
| auto argmin_max_cnode = argmin_max->cast<CNodePtr>(); | |||
| if (argmin_max_cnode == nullptr) { | |||
| MS_LOG(ERROR) << "the second input is not a cnode."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!CheckCNodeIsArgMinMax(argmin_max_cnode)) { | |||
| MS_LOG(DEBUG) << "tuple_get_item first input is not argmin and argmax."; | |||
| return lite::RET_OK; | |||
| } | |||
| auto index_vnode = inputs.at(2); | |||
| auto value_node = index_vnode->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| int index = lite::CastToInt(value_node->value()).front(); | |||
| if (index == 0) { | |||
| if (BuildCNodeForArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "build new cnode failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } else if (index == 1) { | |||
| if (AdjustArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "adjust argmin_max failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| bool MindirInputAdjustOpPass::Run(const FuncGraphPtr &graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto manager = Manage(graph, true); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| auto node_list = TopoSort(graph->get_return()); | |||
| int status = lite::RET_OK; | |||
| for (auto &node : node_list) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| MS_LOG(DEBUG) << "node is not cnode."; | |||
| continue; | |||
| } | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type == schema::PrimitiveType_TupleGetItem) { | |||
| status = AdjustTupleGetItemWithArgMinMax(graph, cnode); | |||
| } | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "adjust input pass is failed."; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_ | |||
| #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "src/param_value_lite.h" | |||
| namespace mindspore::opt { | |||
| class MindirInputAdjustOpPass : public Pass { | |||
| public: | |||
| MindirInputAdjustOpPass() : Pass("mindir_inputs_adjust_pass") {} | |||
| ~MindirInputAdjustOpPass() override = default; | |||
| bool CheckCNodeIsArgMinMax(const CNodePtr &cnode); | |||
| int AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value); | |||
| int CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs); | |||
| int BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max); | |||
| int AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max); | |||
| int AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_ | |||
| @@ -43,8 +43,6 @@ class OnnxInputAdjustOpPass : public Pass { | |||
| STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| bool Run(const FuncGraphPtr &func_graph) override; | |||
| private: | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_ONNX_INPUTS_ADJUST_PASS_H_ | |||