
!10324 [lite] add argmax, layernorm, batchmatmul for mindir

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot (Gitee) committed 5 years ago
commit 00455b9559
16 changed files with 485 additions and 12 deletions
  1. +1 -0    mindspore/lite/schema/model.fbs
  2. +30 -1   mindspore/lite/src/ops/argmax.cc
  3. +1 -0    mindspore/lite/src/ops/argmax.h
  4. +1 -0    mindspore/lite/src/ops/argmin.cc
  5. +53 -0   mindspore/lite/src/ops/gelu.cc
  6. +40 -0   mindspore/lite/src/ops/gelu.h
  7. +38 -4   mindspore/lite/src/ops/layer_norm.cc
  8. +1 -0    mindspore/lite/src/ops/layer_norm.h
  9. +32 -4   mindspore/lite/src/ops/primitive_c.cc
  10. +1 -0   mindspore/lite/test/CMakeLists.txt
  11. +2 -1   mindspore/lite/tools/anf_exporter/anf_exporter.cc
  12. +1 -0   mindspore/lite/tools/converter/CMakeLists.txt
  13. +7 -0   mindspore/lite/tools/converter/anf_transform.cc
  14. +236 -0 mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.cc
  15. +41 -0  mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.h
  16. +0 -2   mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h

+1 -0  mindspore/lite/schema/model.fbs

@@ -261,6 +261,7 @@ union PrimitiveType {
Reciprocal,
Merge,
Mod,
GeLU,
}

enum QuantType: int {
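For reference, GeLU names the Gaussian Error Linear Unit activation added to the schema above. A minimal standalone C++ sketch of its exact erf form (illustrative only; the Lite kernel may use a tanh approximation):

#include <cmath>
#include <cstdio>
#include <initializer_list>

// GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))), the exact (erf) form.
double gelu(double x) { return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0))); }

int main() {
  for (double x : {-2.0, -1.0, 0.0, 1.0, 2.0}) {
    std::printf("gelu(% .1f) = % .6f\n", x, gelu(x));
  }
  return 0;
}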


+30 -1  mindspore/lite/src/ops/argmax.cc

@@ -34,7 +34,36 @@ void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgM
void ArgMax::SetTopK(int top_k) { this->primitive_->value.AsArgMax()->topK = top_k; }
void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->keepDims = keep_dims; }
void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; }

int ArgMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitive error";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_ArgMax;
}
if (this->primitive_->value.type != schema::PrimitiveType_ArgMax) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto argmax_attr = new (std::nothrow) schema::ArgMaxT();
if (argmax_attr == nullptr) {
MS_LOG(ERROR) << "new primitive value.value error";
return RET_ERROR;
}
if (prim.GetAttr("axis") != nullptr) {
argmax_attr->axis = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("axis")));
}
if (prim.GetAttr("keep_dims") != nullptr) {
argmax_attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims")));
}
argmax_attr->outMaxValue = false;
this->primitive_->value.value = argmax_attr;
}
return RET_OK;
}
#else
int ArgMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);


+1 -0  mindspore/lite/src/ops/argmax.h

@@ -37,6 +37,7 @@ class ArgMax : public PrimitiveC {
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+1 -0  mindspore/lite/src/ops/argmin.cc

@@ -61,6 +61,7 @@ int ArgMin::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
if (prim.GetAttr("keep_dims") != nullptr) {
attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims")));
}
attr->outMaxValue = false;
}
return RET_OK;
}


+53 -0  mindspore/lite/src/ops/gelu.cc

@@ -0,0 +1,53 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/gelu.h"
#include <memory>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int GeLU::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_GeLU;
}
if (this->primitive_->value.type != schema::PrimitiveType_GeLU) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::GeLUT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
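The function above follows the converter's usual UnPackAttr shape. A stripped-down sketch of that pattern with toy types (not the real schema classes): allocate the tagged union lazily, verify the tag before reuse, then fill the payload once.

#include <new>

enum class PrimType { Unknown, GeLU };
struct GeLUT {};
struct PrimitiveT {
  PrimType type = PrimType::Unknown;
  GeLUT *value = nullptr;  // stand-in for the union payload
};

int UnPackAttr(PrimitiveT *&primitive) {
  if (primitive == nullptr) {
    primitive = new (std::nothrow) PrimitiveT;
    if (primitive == nullptr) return -1;  // allocation failed
    primitive->type = PrimType::GeLU;
  }
  if (primitive->type != PrimType::GeLU) return -1;  // wrong tag: refuse to overwrite
  if (primitive->value == nullptr) {
    primitive->value = new (std::nothrow) GeLUT;
    if (primitive->value == nullptr) return -1;
  }
  return 0;  // RET_OK equivalent
}

int main() {
  PrimitiveT *p = nullptr;
  int ret = UnPackAttr(p);
  if (p != nullptr) { delete p->value; delete p; }  // the real class owns primitive_
  return ret;
}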

+40 -0  mindspore/lite/src/ops/gelu.h

@@ -0,0 +1,40 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_GELU_H_
#define LITE_MINDSPORE_LITE_C_OPS_GELU_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class GeLU : public PrimitiveC {
public:
GeLU() = default;
~GeLU() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GeLU, PrimitiveC);
explicit GeLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#endif
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_GELU_H_

+38 -4  mindspore/lite/src/ops/layer_norm.cc

@@ -35,7 +35,42 @@ void LayerNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsLayerNorm(
void LayerNorm::SetElementwiseAffine(bool elementwiseAffine) {
this->primitive_->value.AsLayerNorm()->elementwiseAffine = elementwiseAffine;
}

int LayerNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitive error";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_LayerNorm;
}
if (this->primitive_->value.type != schema::PrimitiveType_LayerNorm) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto layer_norm_attr = new (std::nothrow) schema::LayerNormT();
if (layer_norm_attr == nullptr) {
MS_LOG(ERROR) << "new primitive value.value error";
return RET_ERROR;
}
auto value_attr = prim.GetAttr("epsilon");
if (value_attr != nullptr) {
layer_norm_attr->epsilon = GetValue<float>(value_attr);
} else {
layer_norm_attr->epsilon = 1e-7;
}
value_attr = prim.GetAttr("normalized_shape");
if (value_attr != nullptr) {
layer_norm_attr->normalizedShape = CastToInt(value_attr);
}
if (inputs.size() == 3) {
layer_norm_attr->elementwiseAffine = true;
}
this->primitive_->value.value = layer_norm_attr;
}
return RET_OK;
}
#else
int LayerNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
@@ -100,13 +135,12 @@ int LayerNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite:
return RET_PARAM_INVALID;
}
if (normlized_shape_.empty()) {
// instance norm -> layernorm
// instance norm -> layernorm only for nchw
if (input->format() == schema::Format_NCHW) {
normlized_shape_.insert(normlized_shape_.begin(), input_shape.begin() + 2, input_shape.end());
elementwise_mode_ = 1;
} else {
MS_LOG(INFO) << "normalized_shape attr invalid";
return RET_PARAM_INVALID;
normlized_shape_.insert(normlized_shape_.begin(), input_shape.begin() + 1, input_shape.end());
}
}
size_t first_index = input_shape.size() - normlized_shape_.size();


+1 -0  mindspore/lite/src/ops/layer_norm.h

@@ -35,6 +35,7 @@ class LayerNorm : public PrimitiveC {
void SetNormalizedShape(const std::vector<int> &normalizedShape);
void SetEpsilon(float epsilon);
void SetElementwiseAffine(bool elementwiseAffine);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif


+32 -4  mindspore/lite/src/ops/primitive_c.cc

@@ -160,6 +160,7 @@
#include "src/ops/merge.h"
#include "src/ops/switch.h"
#include "src/ops/partial.h"
#include "src/ops/gelu.h"

#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@@ -330,9 +331,28 @@ void PrimitiveC::PopulaterOutputQuantParam(const Primitive &prim, bool narrowRan

void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
auto narrow_range = prim.GetAttr("narrow_range");
bool narrowRangeQuantParam = narrow_range != nullptr && GetValue<bool>(narrow_range);
bool narrowRangeQuantParam = false;
if (narrow_range != nullptr) {
if (utils::isa<tensor::TensorPtr>(narrow_range)) {
auto narrow_range_tensor = narrow_range->cast<tensor::TensorPtr>();
narrowRangeQuantParam = *reinterpret_cast<bool *>(narrow_range_tensor->data_c());
} else if (utils::isa<ImmTraits<bool>::type>(narrow_range)) {
narrowRangeQuantParam = GetValue<bool>(narrow_range);
} else {
MS_LOG(ERROR) << "valueptr is invalid.";
return;
}
}
auto num_bits = prim.GetAttr("num_bits");
int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int64_t>(num_bits) : 8;
int32_t numbitsRangeQuantParam = 8;
if (num_bits != nullptr) {
if (utils::isa<tensor::TensorPtr>(num_bits)) {
auto num_bits_tensor = num_bits->cast<tensor::TensorPtr>();
numbitsRangeQuantParam = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c());
} else if (utils::isa<ImmTraits<int64_t>::type>(num_bits)) {
numbitsRangeQuantParam = GetValue<int64_t>(num_bits);
}
}
PopulaterInputQuantParam(prim, inputs, narrowRangeQuantParam, numbitsRangeQuantParam);
PopulaterOutputQuantParam(prim, narrowRangeQuantParam, numbitsRangeQuantParam);
}
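This hardening exists because an exported MindIR graph may carry narrow_range and num_bits either as scalar immediates or as tensors. A self-contained sketch of the same dual-representation read, with std::variant standing in for ValuePtr (illustrative only):

#include <optional>
#include <variant>
#include <vector>

struct BoolTensor { std::vector<bool> data; };  // stand-in for tensor::Tensor
using Attr = std::variant<bool, BoolTensor>;    // stand-in for ValuePtr

// Missing attribute -> default; otherwise accept either representation.
bool ReadNarrowRange(const std::optional<Attr> &attr) {
  if (!attr.has_value()) return false;
  if (const bool *imm = std::get_if<bool>(&*attr)) return *imm;  // scalar immediate
  return std::get<BoolTensor>(*attr).data.at(0);                 // tensor payload
}

int main() {
  bool a = ReadNarrowRange(std::nullopt);              // false (default)
  bool b = ReadNarrowRange(Attr{true});                // scalar immediate
  bool c = ReadNarrowRange(Attr{BoolTensor{{true}}});  // tensor payload
  return (!a && b && c) ? 0 : 1;
}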
@@ -511,7 +531,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<FusedBatchNorm>(prim, inputs, quantType);
} else if (op_type == "make_tuple") {
return NewPrimitiveC<MakeTuple>(prim, inputs, quantType);
} else if (op_type == "MatMul") {
} else if (op_type == "MatMul" || op_type == "BatchMatMul") {
return NewPrimitiveC<MatMul>(prim, inputs, quantType);
} else if (op_type == "Mul") {
return NewPrimitiveC<Mul>(prim, inputs, quantType);
@@ -601,7 +621,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<TopK>(prim, inputs, quantType);
} else if (op_type == "Mod") {
return NewPrimitiveC<Mod>(prim, inputs, quantType);
} else if (op_type == "ArgMinWithValue") {
} else if (op_type == "ArgMin" || op_type == "ArgMinWithValue") {
return NewPrimitiveC<ArgMin>(prim, inputs, quantType);
} else if (op_type == "Range") {
return NewPrimitiveC<Range>(prim, inputs, quantType);
@@ -621,6 +641,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Partial>(prim, inputs, quantType);
} else if (op_type == "Merge") {
return NewPrimitiveC<Merge>(prim, inputs, quantType);
} else if (op_type == "LayerNorm") {
return NewPrimitiveC<LayerNorm>(prim, inputs, quantType);
} else if (op_type == "ArgMax" || op_type == "ArgMaxWithValue") {
return NewPrimitiveC<ArgMax>(prim, inputs, quantType);
} else if (op_type == "Gelu") {
return NewPrimitiveC<GeLU>(prim, inputs, quantType);

#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
@@ -965,6 +991,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Partial(primitive);
case schema::PrimitiveType_Assert:
return new (std::nothrow) AssertOP(primitive);
case schema::PrimitiveType_GeLU:
return new (std::nothrow) GeLU(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);
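Design aside: each new op (BatchMatMul, ArgMin, LayerNorm, ArgMax, Gelu) grows the if/else chain in Create. The same dispatch could be table-driven; a hedged sketch with toy stand-in classes, not the Lite ones:

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct PrimitiveC { virtual ~PrimitiveC() = default; };
struct MatMul : PrimitiveC {};
struct ArgMin : PrimitiveC {};

std::shared_ptr<PrimitiveC> Create(const std::string &op_type) {
  using Factory = std::function<std::shared_ptr<PrimitiveC>()>;
  static const std::unordered_map<std::string, Factory> kFactories = {
      {"MatMul", [] { return std::make_shared<MatMul>(); }},
      {"BatchMatMul", [] { return std::make_shared<MatMul>(); }},  // alias, as in this commit
      {"ArgMin", [] { return std::make_shared<ArgMin>(); }},
      {"ArgMinWithValue", [] { return std::make_shared<ArgMin>(); }},
  };
  auto it = kFactories.find(op_type);
  return it == kFactories.end() ? nullptr : it->second();
}

int main() { return Create("BatchMatMul") != nullptr ? 0 : 1; }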


+1 -0  mindspore/lite/test/CMakeLists.txt

@@ -199,6 +199,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/while_pass.cc
)


+2 -1  mindspore/lite/tools/anf_exporter/anf_exporter.cc

@@ -737,7 +737,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
meta_graphT->allTensors.emplace_back(msTensor);
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) {
IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
break;
}
#endif


+1 -0  mindspore/lite/tools/converter/CMakeLists.txt

@@ -60,6 +60,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/mindir_adjust_pass.cc
../optimizer/graph/onnx_inputs_adjust_pass.cc
../optimizer/graph/while_pass.cc
../optimizer/graph/mindir_inputs_adjust_pass.cc
)

add_subdirectory(../anf_importer anf_importer)


+7 -0  mindspore/lite/tools/converter/anf_transform.cc

@@ -30,6 +30,7 @@
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include "tools/optimizer/fusion/conv_conv_fusion.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
#include "tools/optimizer/graph/identity_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
@@ -77,6 +78,12 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
auto mindir_inputs_adjust_pass = std::make_shared<opt::MindirInputAdjustOpPass>();
if (!mindir_inputs_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "mindir inputs adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
}

// onnx pre adjustment
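The new pass is wired in with the same run-and-bail convention as mindir_adjust_pass. A minimal sketch of that pipeline shape, with illustrative names rather than the converter API:

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct FuncGraph {};  // stand-in for FuncGraphPtr

using Pass = std::pair<std::string, std::function<bool(FuncGraph *)>>;

// Run passes in order; the first failure aborts the whole transform.
bool RunPasses(FuncGraph *graph, const std::vector<Pass> &passes) {
  for (const auto &p : passes) {
    if (!p.second(graph)) {
      std::cerr << p.first << " failed.\n";  // mirrors MS_LOG(ERROR) + early return
      return false;
    }
  }
  return true;
}

int main() {
  FuncGraph g;
  std::vector<Pass> pipeline = {
      {"mindir_adjust_pass", [](FuncGraph *) { return true; }},
      {"mindir_inputs_adjust_pass", [](FuncGraph *) { return true; }},
  };
  return RunPasses(&g, pipeline) ? 0 : 1;
}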


+236 -0  mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.cc

@@ -0,0 +1,236 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
#include <vector>
#include <memory>
#include "src/common/log_adapter.h"
#include "src/ops/primitive_c.h"
#include "src/tensor.h"

using mindspore::lite::PrimitiveC;
namespace mindspore {
namespace opt {
namespace {
template <typename T>
void CopyAttrForArgMinMax(T *left, T *right) {
MS_ASSERT(left != nullptr && right != nullptr);
left->axis = right->axis;
left->outMaxValue = right->outMaxValue;
left->axisType = right->axisType;
left->keepDims = right->keepDims;
left->topK = right->topK;
}
} // namespace

bool MindirInputAdjustOpPass::CheckCNodeIsArgMinMax(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto prim_node = cnode->inputs().at(0);
MS_ASSERT(prim_node != nullptr);
auto prim_value_node = prim_node->cast<ValueNodePtr>();
if (prim_value_node == nullptr) {
MS_LOG(DEBUG) << "cnode first input is not valueNode.";
return false;
}
auto value = prim_value_node->value();
MS_ASSERT(value != nullptr);
auto prim_c = value->cast<PrimitiveCPtr>();
if (prim_c == nullptr) {
MS_LOG(DEBUG) << "prim is not primitiveC.";
return false;
}
auto prim = prim_c->primitiveT();
MS_ASSERT(prim != nullptr);
return prim->value.type == schema::PrimitiveType_ArgMax || prim->value.type == schema::PrimitiveType_ArgMin;
}

int MindirInputAdjustOpPass::AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value) {
MS_ASSERT(inputs != nullptr);
auto prim_node = inputs->at(0);
MS_ASSERT(prim_node != nullptr);
auto prim_value_node = prim_node->cast<ValueNodePtr>();
if (prim_value_node == nullptr) {
MS_LOG(ERROR) << "cnode first input is not valueNode.";
return lite::RET_ERROR;
}
auto prim_value = prim_value_node->value();
if (prim_value == nullptr) {
MS_LOG(ERROR) << "valueNode value is nullptr.";
return lite::RET_ERROR;
}
auto prim_c = prim_value->cast<PrimitiveCPtr>();
if (prim_c == nullptr) {
MS_LOG(ERROR) << "value is not primitiveC.";
return lite::RET_ERROR;
}
auto prim = prim_c->primitiveT();
MS_ASSERT(prim != nullptr && prim->value.value != nullptr);
auto attr = prim->value.value;
if (prim->value.type == schema::PrimitiveType_ArgMax) {
reinterpret_cast<schema::ArgMaxT *>(attr)->outMaxValue = index_or_value;
} else if (prim->value.type == schema::PrimitiveType_ArgMin) {
reinterpret_cast<schema::ArgMinT *>(attr)->outMaxValue = index_or_value;
}
return lite::RET_OK;
}

int MindirInputAdjustOpPass::CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs) {
MS_ASSERT(inputs != nullptr);
auto prim_node = inputs->at(0);
MS_ASSERT(prim_node != nullptr);
auto prim_value_node = prim_node->cast<ValueNodePtr>();
if (prim_value_node == nullptr) {
MS_LOG(ERROR) << "cnode first input is not valueNode.";
return lite::RET_ERROR;
}
auto prim_value = prim_value_node->value();
if (prim_value == nullptr) {
MS_LOG(ERROR) << "valueNode value is nullptr.";
return lite::RET_ERROR;
}
auto prim_c = prim_value->cast<PrimitiveCPtr>();
if (prim_c == nullptr) {
MS_LOG(ERROR) << "value is not primitiveC.";
return lite::RET_ERROR;
}
auto prim = prim_c->primitiveT();
MS_ASSERT(prim != nullptr && prim->value.value != nullptr);
auto primitive = std::make_unique<schema::PrimitiveT>();
if (prim->value.type == schema::PrimitiveType_ArgMax) {
primitive->value.type = schema::PrimitiveType_ArgMax;
auto attr = std::make_unique<schema::ArgMaxT>();
CopyAttrForArgMinMax<schema::ArgMaxT>(attr.get(), reinterpret_cast<schema::ArgMaxT *>(prim->value.value));
primitive->value.value = attr.release();
} else {
primitive->value.type = schema::PrimitiveType_ArgMin;
auto attr = std::make_unique<schema::ArgMinT>();
CopyAttrForArgMinMax<schema::ArgMinT>(attr.get(), reinterpret_cast<schema::ArgMinT *>(prim->value.value));
primitive->value.value = attr.release();
}
auto primitive_c = PrimitiveC::Create(primitive.release());
auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitive_c));
inputs->erase(inputs->begin());
inputs->insert(inputs->begin(), value_node);
return lite::RET_OK;
}

int MindirInputAdjustOpPass::BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item,
const CNodePtr &argmin_max) {
MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr);
auto inputs = argmin_max->inputs();
if (CopyPrimitiveCForArgMinMax(&inputs) != lite::RET_OK) {
MS_LOG(ERROR) << "copy argmin or argmax failed.";
return lite::RET_ERROR;
}
if (AdjustArgMinMaxInputs(&inputs, false) != lite::RET_OK) {
MS_LOG(ERROR) << "adjust argmin or argmax attr failed.";
return lite::RET_ERROR;
}
auto new_cnode = graph->NewCNode(inputs);
new_cnode->set_fullname_with_scope(argmin_max->fullname_with_scope() + "_index");
auto type_ptr = TypeIdToType(kTypeUnknown);
std::vector<int64_t> shape_vector;
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto manager = graph->manager();
MS_ASSERT(manager != nullptr);
manager->Replace(tuple_get_item, new_cnode);
return lite::RET_OK;
}

int MindirInputAdjustOpPass::AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item,
const CNodePtr &argmin_max) {
MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr);
auto inputs = argmin_max->inputs();
if (AdjustArgMinMaxInputs(&inputs, true) != lite::RET_OK) {
MS_LOG(ERROR) << "adjust argmin or argmax attr failed.";
return lite::RET_ERROR;
}
auto type_ptr = TypeIdToType(kTypeUnknown);
std::vector<int64_t> shape_vector;
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
argmin_max->set_abstract(abstract_tensor);
auto manager = graph->manager();
MS_ASSERT(manager != nullptr);
manager->Replace(tuple_get_item, argmin_max);
return lite::RET_OK;
}

int MindirInputAdjustOpPass::AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode) {
MS_ASSERT(graph != nullptr && cnode != nullptr);
auto inputs = cnode->inputs();
if (inputs.size() != 3) {
MS_LOG(ERROR) << "tupleGetItem inputs size is invalid: " << inputs.size();
return lite::RET_ERROR;
}
auto argmin_max = inputs.at(1);
MS_ASSERT(argmin_max != nullptr);
auto argmin_max_cnode = argmin_max->cast<CNodePtr>();
if (argmin_max_cnode == nullptr) {
MS_LOG(ERROR) << "the second input is not a cnode.";
return lite::RET_ERROR;
}
if (!CheckCNodeIsArgMinMax(argmin_max_cnode)) {
MS_LOG(DEBUG) << "tuple_get_item first input is not argmin and argmax.";
return lite::RET_OK;
}
auto index_vnode = inputs.at(2);
auto value_node = index_vnode->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
return lite::RET_ERROR;
}
int index = lite::CastToInt(value_node->value()).front();
if (index == 0) {
if (BuildCNodeForArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "build new cnode failed.";
return lite::RET_ERROR;
}
} else if (index == 1) {
if (AdjustArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "adjust argmin_max failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}

bool MindirInputAdjustOpPass::Run(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto manager = Manage(graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return lite::RET_NULL_PTR;
}
auto node_list = TopoSort(graph->get_return());
int status = lite::RET_OK;
for (auto &node : node_list) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
MS_LOG(DEBUG) << "node is not cnode.";
continue;
}
auto type = opt::GetCNodeType(node);
if (type == schema::PrimitiveType_TupleGetItem) {
status = AdjustTupleGetItemWithArgMinMax(graph, cnode);
}
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "adjust input pass is failed.";
return false;
}
}
return true;
}
} // namespace opt
} // namespace mindspore
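In summary, this pass bridges an interface gap: MindIR's ArgMaxWithValue/ArgMinWithValue produce an (index, value) tuple, while the Lite ArgMax/ArgMin op emits a single tensor selected by outMaxValue. A toy sketch of the per-TupleGetItem decision (plain structs, not ANF nodes):

#include <cassert>

struct ArgMinMaxNode { bool outMaxValue = false; };

// TupleGetItem(op, 0) -> cloned "_index" node with outMaxValue = false;
// TupleGetItem(op, 1) -> the node itself with outMaxValue = true.
ArgMinMaxNode LowerTupleGetItem(const ArgMinMaxNode &node, int index) {
  ArgMinMaxNode lowered = node;
  lowered.outMaxValue = (index == 1);
  return lowered;
}

int main() {
  ArgMinMaxNode n;
  assert(!LowerTupleGetItem(n, 0).outMaxValue);  // index branch
  assert(LowerTupleGetItem(n, 1).outMaxValue);   // value branch
  return 0;
}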

+41 -0  mindspore/lite/tools/optimizer/graph/mindir_inputs_adjust_pass.h

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_

#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/param_value_lite.h"

namespace mindspore::opt {
class MindirInputAdjustOpPass : public Pass {
public:
MindirInputAdjustOpPass() : Pass("mindir_inputs_adjust_pass") {}
~MindirInputAdjustOpPass() override = default;
bool CheckCNodeIsArgMinMax(const CNodePtr &cnode);
int AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value);
int CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs);
int BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max);
int AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max);
int AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode);
bool Run(const FuncGraphPtr &graph) override;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_

+0 -2  mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h

@@ -43,8 +43,6 @@ class OnnxInputAdjustOpPass : public Pass {
STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
bool Run(const FuncGraphPtr &func_graph) override;

private:
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_ONNX_INPUTS_ADJUST_PASS_H_
