
!13296 [MS][LITE] Support fl bert convert

From: @cjh9368
Reviewed-by: @zhanghaibo5,@hangangqiang
Signed-off-by: @hangangqiang
Tag: v1.2.0-rc1
Committed by mindspore-ci-bot (Gitee), parent commit 8e582869b9
9 changed files with 76 additions and 14 deletions
  1. mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.h (+1, -1)
  2. mindspore/lite/src/ops/ops_utils.cc (+0, -3)
  3. mindspore/lite/src/ops/populate/arithmetic_populate.cc (+3, -1)
  4. mindspore/lite/src/ops/populate/common_populate.cc (+1, -0)
  5. mindspore/lite/src/ops/populate/strided_slice_grad_populate.cc (+45, -0)
  6. mindspore/lite/tools/anf_exporter/anf_exporter.cc (+19, -4)
  7. mindspore/lite/tools/anf_exporter/anf_exporter.h (+2, -0)
  8. mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc (+2, -2)
  9. mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.cc (+3, -3)

mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.h (+1, -1)

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.


mindspore/lite/src/ops/ops_utils.cc (+0, -3)

@@ -668,12 +668,10 @@ schema::PrimitiveT *StridedSlicePrimitiveCreator(const AnfNodePtr &node) {
  auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::StridedSlice>>(node);
  return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}

schema::PrimitiveT *StridedSliceGradPrimitiveCreator(const AnfNodePtr &node) {
  auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::StridedSliceGrad>>(node);
  return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}

schema::PrimitiveT *SubFusionPrimitiveCreator(const AnfNodePtr &node) {
  auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SubFusion>>(node);
  return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
@@ -834,7 +832,6 @@ RegistryMSOps g_fullConnectionPrimitiveCreatorRegistry("FullConnection", FullCon
 RegistryMSOps g_fusedBatchNormPrimitiveCreatorRegistry("FusedBatchNorm", FusedBatchNormPrimitiveCreator);
 RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator);
 RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator);
-RegistryMSOps g_geluGradPrimitiveCreatorRegistry("GeluGrad", ActivationGradPrimitiveCreator);
 RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator);
 RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator);
 RegistryMSOps g_gRUPrimitiveCreatorRegistry("GRU", GRUPrimitiveCreator);


mindspore/lite/src/ops/populate/arithmetic_populate.cc (+3, -1)

@@ -42,8 +42,10 @@ OpParameter *PopulateArithmetic(const void *primitive) {
   return reinterpret_cast<OpParameter *>(param);
 }
 
+Registry g_MinimunGradParameterRegistry(schema::PrimitiveType_MinimumGrad, PopulateArithmetic, SCHEMA_CUR);
+Registry g_MaximunGradParameterRegistry(schema::PrimitiveType_MaximumGrad, PopulateArithmetic, SCHEMA_CUR);
 Registry g_realDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithmetic, SCHEMA_CUR);
-Registry g_ogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic, SCHEMA_CUR);
+Registry g_logicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic, SCHEMA_CUR);
 Registry g_parameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic, SCHEMA_CUR);
 Registry g_equalParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic, SCHEMA_CUR);
 Registry g_notEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic, SCHEMA_CUR);
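The Registry globals above (and the RegistryMSOps creators in ops_utils.cc) rely on C++ static initialization to self-register a populate function for each schema primitive type before any conversion runs. Below is a minimal, self-contained sketch of that pattern; the Registry, PopulateTable, and type names are illustrative stand-ins, not MindSpore's actual registry API.

#include <cstdio>
#include <functional>
#include <map>

// Illustrative stand-ins for schema::PrimitiveType and OpParameter.
enum PrimitiveType { PrimitiveType_LogicalAnd, PrimitiveType_Equal };
struct OpParameter { int type_; };

using PopulateFunc = std::function<OpParameter *(const void *)>;

// Process-wide table mapping primitive type -> populate function. A function-local
// static sidesteps the static-initialization-order problem across translation units.
static std::map<PrimitiveType, PopulateFunc> &PopulateTable() {
  static std::map<PrimitiveType, PopulateFunc> table;
  return table;
}

// Constructing a Registry object at namespace scope inserts an entry during
// static initialization, so every op is registered before main() runs.
struct Registry {
  Registry(PrimitiveType type, PopulateFunc func) { PopulateTable()[type] = std::move(func); }
};

OpParameter *PopulateArithmetic(const void * /*primitive*/) { return new OpParameter{0}; }

// Mirrors the g_...ParameterRegistry globals in the hunk above.
static Registry g_logicalAndRegistry(PrimitiveType_LogicalAnd, PopulateArithmetic);

int main() {
  auto it = PopulateTable().find(PrimitiveType_LogicalAnd);
  std::printf("LogicalAnd registered: %s\n", it != PopulateTable().end() ? "yes" : "no");
  if (it != PopulateTable().end()) delete it->second(nullptr);
  return 0;
}

Note that the g_ogicalAnd -> g_logicalAnd rename only fixes the identifier's spelling; the registration key is the PrimitiveType_LogicalAnd enum value, so runtime behavior is unchanged.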


mindspore/lite/src/ops/populate/common_populate.cc (+1, -0)

@@ -32,5 +32,6 @@ OpParameter *PopulateCommonParameter(const void *prim) {
 }  // namespace
 
 Registry g_zerosLikeParameterRegistry(schema::PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR);
+Registry g_dependParameterRegistry(schema::PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR);
 }  // namespace lite
 }  // namespace mindspore

mindspore/lite/src/ops/populate/strided_slice_grad_populate.cc (+45, -0)

@@ -0,0 +1,45 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/ops/populate/populate_register.h"
#include "nnacl/strided_slice_parameter.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
  auto *strided_slice_param = reinterpret_cast<StridedSliceParameter *>(malloc(sizeof(StridedSliceParameter)));
  if (strided_slice_param == nullptr) {
    MS_LOG(ERROR) << "malloc StridedSliceParameter failed.";
    return nullptr;
  }
  memset(strided_slice_param, 0, sizeof(StridedSliceParameter));

  auto primitive = static_cast<const schema::Primitive *>(prim);
  auto value = primitive->value_as_StridedSliceGrad();
  strided_slice_param->op_parameter_.type_ = primitive->value_type();

  strided_slice_param->begins_mask_ = value->begin_mask();
  strided_slice_param->ends_mask_ = value->end_mask();
  strided_slice_param->ellipsisMask_ = value->ellipsis_mask();
  strided_slice_param->newAxisMask_ = value->new_axis_mask();
  strided_slice_param->shrinkAxisMask_ = value->shrink_axis_mask();
  return reinterpret_cast<OpParameter *>(strided_slice_param);
}

Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad, PopulateStridedSliceGradParameter,
                                           SCHEMA_CUR);

}  // namespace lite
}  // namespace mindspore
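One caveat about the new populate function: FlatBuffers' generated value_as_StridedSliceGrad() returns nullptr when the primitive's value union holds a different type, and the code above dereferences value without a check. If hardening were wanted, a guard could slot in right after the cast; this is a sketch, not part of the commit:

  auto value = primitive->value_as_StridedSliceGrad();
  if (value == nullptr) {  // value union carries a different primitive type
    MS_LOG(ERROR) << "value_as_StridedSliceGrad returned nullptr";
    free(strided_slice_param);  // release the parameter struct malloc'd above
    return nullptr;
  }

In practice the registry routes only PrimitiveType_StridedSliceGrad primitives here, so the cast should always succeed; the guard is insurance against a malformed flatbuffer.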

mindspore/lite/tools/anf_exporter/anf_exporter.cc (+19, -4)

@@ -670,6 +670,23 @@ void AnfExporter::ProcessBoolImm(const ValueNodePtr &valueNode, std::unique_ptr<
   output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
   meta_graphT->allTensors.emplace_back(std::move(*paramTensor));
 }
+int AnfExporter::ProcessNumber(const ValueNodePtr &valueNode, schema::TensorT *paramTensor,
+                               schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
+  auto data = valueNode->value()->cast<NumberPtr>();
+  paramTensor->data.resize(sizeof(int));
+  int number_type = data->number_type();
+  if (EOK != ::memcpy_s(paramTensor->data.data(), sizeof(int), &number_type, sizeof(int))) {
+    MS_LOG(ERROR) << "memcpy_s failed";
+    return RET_MEMORY_FAILED;
+  }
+  paramTensor->dataType = kNumberTypeInt32;
+  paramTensor->dims = {1};
+  paramTensor->nodeType = schema::NodeType_ValueNode;
+  node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
+  output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
+  meta_graphT->allTensors.emplace_back(paramTensor);
+  return RET_OK;
+}
 void AnfExporter::ProcessInt(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
                              schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
   (*paramTensor)->dataType = kNumberTypeInt32;
@@ -765,8 +782,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
   } else if (value->isa<mindspore::ValueSequeue>()) {
     ret = ProcessValueSequence(valueNode, &paramTensor, value, output_cnode, meta_graphT);
   } else if (value->isa<Number>()) {
-    MS_LOG(INFO) << "Value is a number.";
-    return RET_OK;
+    ret = ProcessNumber(valueNode, paramTensor.release(), output_cnode, meta_graphT);
   } else if (value->isa<mindspore::ParamValueLite>()) {
     ret = ProcessParamValueLite(valueNode, &paramTensor, value, output_cnode, meta_graphT);
   } else if (value->isa<FuncGraph>()) {
@@ -878,8 +894,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
     msTensor->dataType = type;
     meta_graphT->allTensors.emplace_back(msTensor);
     if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
-        opt::CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm) ||
-        opt::CheckPrimitiveType(cnode, prim::kPrimLayerNormFusion)) {
+        opt::CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm)) {
       break;
     }
   }
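Two details in the exporter change are easy to miss. Dropping kPrimLayerNormFusion from the special case in SetOpOutputNode removes the early break, so LayerNormFusion nodes now export all of their output tensors rather than only the first. And unlike the neighboring Process* helpers, which take std::unique_ptr<schema::TensorT> *, ProcessNumber receives a raw pointer via paramTensor.release(): ownership transfers to meta_graphT->allTensors when the raw pointer is emplaced, because the container's unique_ptr element adopts it. A self-contained sketch of that release-then-adopt pattern, with stand-in types:

#include <cstdio>
#include <memory>
#include <vector>

struct TensorT { int dataType = 0; };  // stand-in for schema::TensorT

// Adopts the raw pointer: emplace_back wraps it in the vector's unique_ptr, which
// then owns it (mirrors meta_graphT->allTensors.emplace_back(paramTensor)).
void AddTensor(std::vector<std::unique_ptr<TensorT>> *all_tensors, TensorT *tensor) {
  tensor->dataType = 34;  // e.g. kNumberTypeInt32
  all_tensors->emplace_back(tensor);
}

int main() {
  std::vector<std::unique_ptr<TensorT>> all_tensors;
  auto param_tensor = std::make_unique<TensorT>();
  // release() hands ownership to the callee exactly once; the caller must not free it.
  AddTensor(&all_tensors, param_tensor.release());
  std::printf("tensors: %zu, dataType: %d\n", all_tensors.size(), all_tensors[0]->dataType);
  return 0;
}

One consequence worth flagging: if memcpy_s fails inside ProcessNumber, the early RET_MEMORY_FAILED return leaks the tensor, since the caller has already released it and nothing has adopted it yet.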


mindspore/lite/tools/anf_exporter/anf_exporter.h (+2, -0)

@@ -64,6 +64,8 @@ class AnfExporter {
                   const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
   void ProcessInt(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
                   schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
+  int ProcessNumber(const ValueNodePtr &valueNode, schema::TensorT *paramTensor, schema::CNodeT *output_cnode,
+                    const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
   int ProcessValueSequence(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
                            const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                            const std::unique_ptr<schema::MetaGraphT> &meta_graphT);


mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc (+2, -2)

@@ -94,7 +94,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
   MS_ASSERT(graph->allTensors.size() > mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
   const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
   MS_ASSERT(mulNodeBiasTensor != nullptr);
-  if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode) {
+  if (mulNodeBiasTensor->nodeType != schema::NodeType::NodeType_ValueNode) {
     // dont fusion, return
     return RET_OK;
   }
@@ -111,7 +111,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
   MS_ASSERT(graph->allTensors.size() > addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
   const auto &addNodeBiasTensor = graph->allTensors.at(addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
   MS_ASSERT(addNodeBiasTensor != nullptr);
-  if (addNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode) {
+  if (addNodeBiasTensor->nodeType != schema::NodeType::NodeType_ValueNode) {
     // dont fusion, return
     return RET_OK;
   }
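Both hunks fix the same latent bug: refCount (a liveness counter) was compared against schema::NodeType::NodeType_ValueNode (an enum constant), so whether the fusion bailed out had nothing to do with whether the bias was actually a constant. The corrected check inspects nodeType, which is what distinguishes a compile-time constant from a runtime tensor. A small self-contained sketch with simplified stand-ins for the schema types:

#include <cstdio>

// Simplified stand-ins; the real enum lives in the MindSpore Lite schema.
enum NodeType { NodeType_ValueNode = 0, NodeType_Parameter = 1, NodeType_CNode = 2 };
struct TensorT {
  int refCount = 0;       // liveness bookkeeping, unrelated to constness
  NodeType nodeType = NodeType_CNode;
};

// Mul+Add folding is only safe when the bias is a compile-time constant.
bool CanFoldBias(const TensorT &bias) { return bias.nodeType == NodeType_ValueNode; }

int main() {
  TensorT constant_bias{/*refCount=*/3, NodeType_ValueNode};
  // The old check compared refCount (here 3) with NodeType_ValueNode (here 0),
  // so this perfectly foldable constant bias would have been skipped.
  std::printf("old check skips fusion: %s\n",
              constant_bias.refCount != NodeType_ValueNode ? "yes" : "no");
  std::printf("new check allows fusion: %s\n", CanFoldBias(constant_bias) ? "yes" : "no");
  return 0;
}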


mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.cc (+3, -3)

@@ -134,7 +134,7 @@ constexpr auto kNameReluGrad = "ReluGrad";
 constexpr auto kNameReLU6Grad = "ReLU6Grad";
 constexpr auto kNameSigmoidGrad = "SigmoidGrad";
 constexpr auto kNameEluGrad = "EluGrad";
-constexpr auto kNameGeluGrad = "GeluGrad";
+constexpr auto kNameGeLUGrad = "GeLUGrad";
 constexpr auto kNameSlice = "Slice";
 constexpr auto kNameAvgPoolGradGpu = "AvgPoolGradGpu";
 constexpr auto kNameAvgPoolGradCpu = "AvgPoolGradCpu";
@@ -155,7 +155,7 @@ std::map<std::string, mindspore::ActivationType> activation_map = {{ops::kNameEl
                                                           {kNameReLU6Grad, mindspore::RELU6},
                                                           {kNameSigmoidGrad, mindspore::SIGMOID},
                                                           {kNameEluGrad, mindspore::ELU},
-                                                          {kNameGeluGrad, mindspore::GELU},
+                                                          {kNameGeLUGrad, mindspore::GELU},
                                                           {kNameTanhGrad, mindspore::TANH}};
 
 std::map<std::string, mindspore::ReduceMode> reduce_map = {
@@ -550,7 +550,7 @@ REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon<ops::ExpFusion>)
 REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormEx, MoveAttrMapCommon<ops::FusedBatchNorm>)
 REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapCommon<ops::BatchNormGrad>)
 REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation)
-REGIST_PRIMITIVE_ADJUST(kNameGeluGrad, MoveAttrMapActivationGrad)
+REGIST_PRIMITIVE_ADJUST(kNameGeLUGrad, MoveAttrMapActivationGrad)
 REGIST_PRIMITIVE_ADJUST(kNameHSigmoid, MoveAttrMapActivation)
 REGIST_PRIMITIVE_ADJUST(kNameHSigmoidGrad, MoveAttrMapActivationGrad)
 REGIST_PRIMITIVE_ADJUST(kNameHSwish, MoveAttrMapActivation)
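The kNameGeluGrad -> kNameGeLUGrad change is not cosmetic: the constant is the lookup key both in activation_map and in REGIST_PRIMITIVE_ADJUST, and those keys are exact, case-sensitive strings that must match the name the op is actually exported under ("GeLUGrad", consistent with kNameGeLU above). A minimal sketch of why the old spelling silently misses (the map here is illustrative, not MindSpore's API):

#include <cstdio>
#include <map>
#include <string>

enum ActivationType { GELU = 1 };

int main() {
  // Keys must match the op's registered name byte-for-byte.
  std::map<std::string, ActivationType> activation_map = {{"GeLUGrad", GELU}};

  std::printf("GeluGrad found: %s\n", activation_map.count("GeluGrad") ? "yes" : "no");  // no
  std::printf("GeLUGrad found: %s\n", activation_map.count("GeLUGrad") ? "yes" : "no");  // yes
  return 0;
}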

