
!7550 add layer_norm op and layer_norm fusion for transformer model

Merge pull request !7550 from wangzhe/master
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 155d2bb6df
16 changed files with 857 additions and 0 deletions
  1. mindspore/lite/nnacl/fp32/layer_norm.c (+51, -0)
  2. mindspore/lite/nnacl/fp32/layer_norm.h (+33, -0)
  3. mindspore/lite/nnacl/layer_norm_parameter.h (+29, -0)
  4. mindspore/lite/schema/model.fbs (+1, -0)
  5. mindspore/lite/schema/ops.fbs (+7, -0)
  6. mindspore/lite/src/ops/layer_norm.cc (+105, -0)
  7. mindspore/lite/src/ops/layer_norm.h (+50, -0)
  8. mindspore/lite/src/ops/primitive_c.cc (+3, -0)
  9. mindspore/lite/src/populate_parameter.cc (+24, -0)
  10. mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm.cc (+111, -0)
  11. mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm.h (+52, -0)
  12. mindspore/lite/test/CMakeLists.txt (+1, -0)
  13. mindspore/lite/tools/converter/CMakeLists.txt (+1, -0)
  14. mindspore/lite/tools/converter/anf_transform.cc (+2, -0)
  15. mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc (+332, -0)
  16. mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h (+55, -0)

mindspore/lite/nnacl/fp32/layer_norm.c (+51, -0)

@@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/layer_norm.h"
#include <math.h>
#include "nnacl/errorcode.h"
#include "nnacl/op_base.h"

int LayerNorm(const int outer_size, const int inner_size, const float *src_data, const float *gamma_data,
              const float *beta_data, const bool affine, const float epsilon, float *dst_data, const int tid,
              const int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  if (affine && (gamma_data == NULL || beta_data == NULL)) {
    return NNACL_NULL_PTR;
  }
  int i, j;
  for (j = tid; j < outer_size; j += thread_num) {
    const float *src = src_data + j * inner_size;
    float *dst = dst_data + j * inner_size;
    float mean = 0.0f;
    float square_mean = 0.0f;
    for (i = 0; i < inner_size; i++) {
      mean += src[i];
      square_mean += src[i] * src[i];
    }
    mean /= (float)inner_size;
    square_mean /= (float)inner_size;
    float deno = 1 / sqrtf(square_mean - mean * mean + epsilon);
    for (i = 0; i < inner_size; ++i) {
      dst[i] = (src[i] - mean) * deno;
      if (affine) {
        dst[i] = dst[i] * gamma_data[i] + beta_data[i];
      }
    }
  }
  return NNACL_OK;
}
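For reference, a sketch (not part of the change) of the per-row arithmetic the loop above performs, with n = inner_size and the row elements x_1, ..., x_n; the biased variance is obtained as E[x^2] - E[x]^2:

\[ \mu = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad \sigma^2 = \frac{1}{n}\sum_{i=1}^{n} x_i^2 - \mu^2, \qquad y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \]

followed by y_i <- gamma_i * y_i + beta_i when affine is true. Rows are distributed across workers by striding: the worker with id tid handles rows tid, tid + thread_num, tid + 2 * thread_num, and so on.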

mindspore/lite/nnacl/fp32/layer_norm.h (+33, -0)

@@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_
#define MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_

#include "nnacl/op_base.h"
#include "nnacl/layer_norm_parameter.h"

#ifdef __cplusplus
extern "C" {
#endif

int LayerNorm(const int outer_size, const int inner_size, const float *src_data, const float *gamma_data,
              const float *beta_data, const bool affine, const float epsilon, float *dst_data, const int tid,
              const int thread_num);
#ifdef __cplusplus
}
#endif

#endif // MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_
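A minimal, hypothetical standalone check of this C entry point (not part of the PR; main(), the literal values, and the single-threaded tid = 0 / thread_num = 1 call are illustrative assumptions):

/* Hypothetical usage sketch: normalize two rows of four elements each over the last axis. */
#include <stdbool.h>
#include <stdio.h>
#include "nnacl/errorcode.h"
#include "nnacl/fp32/layer_norm.h"

int main(void) {
  float src[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
  float gamma[4] = {1.0f, 1.0f, 1.0f, 1.0f};
  float beta[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  float dst[8] = {0.0f};
  /* outer_size = 2 rows, inner_size = 4, affine = true, single thread (tid 0 of 1). */
  int ret = LayerNorm(2, 4, src, gamma, beta, true, 1e-5f, dst, 0, 1);
  if (ret != NNACL_OK) {
    return ret;
  }
  for (int i = 0; i < 8; ++i) {
    printf("dst[%d] = %f\n", i, dst[i]);
  }
  return 0;
}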

mindspore/lite/nnacl/layer_norm_parameter.h (+29, -0)

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_

#include "nnacl/op_base.h"

typedef struct LayerNormParameter {
  OpParameter op_parameter_;
  int *normalized_shape_;
  int normalized_dims_;
  float epsilon_;
  bool elementwise_affine_;
} LayerNormParameter;

#endif // MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_

mindspore/lite/schema/model.fbs (+1, -0)

@@ -223,6 +223,7 @@ union PrimitiveType {
    NonMaxSuppression,
    InstanceNorm,
    Identity,
    LayerNorm,
}

enum QuantType: int {


mindspore/lite/schema/ops.fbs (+7, -0)

@@ -1097,3 +1097,10 @@ table Loop {


table Identity {
}

table LayerNorm {
    normalizedShape : [int];
    epsilon : float = 0.00001;
    elementwiseAffine : bool;
}


mindspore/lite/src/ops/layer_norm.cc (+105, -0)

@@ -0,0 +1,105 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/layer_norm.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> LayerNorm::GetNormalizedShape() const {
return this->primitive_->value.AsLayerNorm()->normalizedShape;
}
float LayerNorm::GetEpsilon() const { return this->primitive_->value.AsLayerNorm()->epsilon; }
bool LayerNorm::GetElementwiseAffine() const { return this->primitive_->value.AsLayerNorm()->elementwiseAffine; }

void LayerNorm::SetNormalizedShape(const std::vector<int> &normalizedShape) {
this->primitive_->value.AsLayerNorm()->normalizedShape = normalizedShape;
}
void LayerNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsLayerNorm()->epsilon = epsilon; }
void LayerNorm::SetElementwiseAffine(bool elementwiseAffine) {
this->primitive_->value.AsLayerNorm()->elementwiseAffine = elementwiseAffine;
}

#else
int LayerNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_LayerNorm();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_LayerNorm return nullptr";
return RET_ERROR;
}

std::vector<int32_t> normalizedShape;
if (attr->normalizedShape() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->normalizedShape()->size()); i++) {
normalizedShape.push_back(attr->normalizedShape()->data()[i]);
}
}
auto val_offset = schema::CreateLayerNormDirect(*fbb, &normalizedShape, attr->epsilon(), attr->elementwiseAffine());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LayerNorm, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
std::vector<int> LayerNorm::GetNormalizedShape() const {
auto fb_vector = this->primitive_->value_as_LayerNorm()->normalizedShape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
float LayerNorm::GetEpsilon() const { return this->primitive_->value_as_LayerNorm()->epsilon(); }
bool LayerNorm::GetElementwiseAffine() const { return this->primitive_->value_as_LayerNorm()->elementwiseAffine(); }

#endif
int LayerNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
if (outputs_.size() != kSingleNum || (inputs_.size() != kSingleNum && inputs_.size() != kMultiNum)) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs_.size() << ",input size: " << inputs_.size();
return RET_PARAM_INVALID;
}
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.at(0);
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());

if (GetElementwiseAffine() && inputs_.size() != kMultiNum) {
MS_LOG(INFO) << "input tensor amount error";
return RET_INPUT_TENSOR_ERROR;
}
if (!GetElementwiseAffine() && inputs_.size() != kSingleNum) {
MS_LOG(INFO) << "input tensor amount error";
return RET_INPUT_TENSOR_ERROR;
}
auto input_shape = input->shape();
auto normalized_shape = GetNormalizedShape();
if (normalized_shape.size() > input_shape.size() || normalized_shape.size() == 0) {
MS_LOG(INFO) << "normalized_shape attr invalid";
return RET_PARAM_INVALID;
}
size_t first_index = input_shape.size() - normalized_shape.size();
for (size_t i = first_index; i < input_shape.size(); ++i) {
if (input_shape[i] != normalized_shape[i - first_index]) {
MS_LOG(INFO) << "normalized_shape attr invalid";
return RET_PARAM_INVALID;
}
}
if (!GetInferFlag()) {
return RET_OK;
}

output->set_shape(input_shape);
return RET_OK;
}
} // namespace lite
} // namespace mindspore
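As a concrete reading of InferShape above (shapes are illustrative, not taken from the patch): an input of shape [1, 128, 768] with normalizedShape = [768] is accepted because the trailing axis matches, and with elementwiseAffine = true the op expects gamma and beta tensors as its second and third inputs and infers an output shape of [1, 128, 768]. A normalizedShape of [128, 768] would also be accepted, while [128] would be rejected since 128 does not match the last input dimension of 768.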

mindspore/lite/src/ops/layer_norm.h (+50, -0)

@@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_LAYER_NORM_H_
#define MINDSPORE_LITE_SRC_OPS_LAYER_NORM_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>

#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class LayerNorm : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LayerNorm, PrimitiveC);
LayerNorm() = default;
explicit LayerNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetNormalizedShape(const std::vector<int> &normalizedShape);
void SetEpsilon(float epsilon);
void SetElementwiseAffine(bool elementwiseAffine);
#else
LayerNorm() = default;

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
std::vector<int> GetNormalizedShape() const;
float GetEpsilon() const;
bool GetElementwiseAffine() const;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_OPS_LAYER_NORM_H_

mindspore/lite/src/ops/primitive_c.cc (+3, -0)

@@ -135,6 +135,7 @@
#include "src/ops/custom_normalize.h" #include "src/ops/custom_normalize.h"
#include "src/ops/custom_extract_features.h" #include "src/ops/custom_extract_features.h"
#include "src/ops/upsample.h" #include "src/ops/upsample.h"
#include "src/ops/layer_norm.h"


#ifdef SUPPORT_TRAIN #ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h" #include "src/ops/neg_grad.h"
@@ -723,6 +724,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new CustomExtractFeatures(primitive); return new CustomExtractFeatures(primitive);
case schema::PrimitiveType_Upsample: case schema::PrimitiveType_Upsample:
return new Upsample(primitive); return new Upsample(primitive);
case schema::PrimitiveType_LayerNorm:
return new LayerNorm(primitive);


#ifdef SUPPORT_TRAIN #ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad: case schema::PrimitiveType_ActivationGrad:


mindspore/lite/src/populate_parameter.cc (+24, -0)

@@ -120,6 +120,7 @@
#include "src/ops/detection_post_process.h" #include "src/ops/detection_post_process.h"
#include "src/ops/skip_gram.h" #include "src/ops/skip_gram.h"
#include "src/ops/custom_predict.h" #include "src/ops/custom_predict.h"
#include "src/ops/layer_norm.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "nnacl/fp32/arg_min_max.h" #include "nnacl/fp32/arg_min_max.h"
#include "nnacl/fp32/cast.h" #include "nnacl/fp32/cast.h"
@@ -183,6 +184,7 @@
#include "nnacl/fp32/exp.h" #include "nnacl/fp32/exp.h"
#include "nnacl/fp32/skip_gram.h" #include "nnacl/fp32/skip_gram.h"
#include "nnacl/predict_parameter.h" #include "nnacl/predict_parameter.h"
#include "nnacl/layer_norm_parameter.h"


namespace mindspore::kernel { namespace mindspore::kernel {


@@ -1674,6 +1676,27 @@ OpParameter *PopulateCustomPredictParameter(const mindspore::lite::PrimitiveC *p
return reinterpret_cast<OpParameter *>(param); return reinterpret_cast<OpParameter *>(param);
} }


OpParameter *PopulateLayerNormParameter(const mindspore::lite::PrimitiveC *primitive) {
auto layer_norm_parameter = reinterpret_cast<LayerNormParameter *>(malloc(sizeof(LayerNormParameter)));
if (layer_norm_parameter == nullptr) {
MS_LOG(ERROR) << "malloc LayerNormParameter failed.";
return nullptr;
}
memset(layer_norm_parameter, 0, sizeof(LayerNormParameter));
layer_norm_parameter->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::LayerNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
auto normalized_shape = param->GetNormalizedShape();
layer_norm_parameter->normalized_dims_ = normalized_shape.size();
layer_norm_parameter->normalized_shape_ = reinterpret_cast<int *>(malloc(normalized_shape.size() * sizeof(int)));
for (size_t i = 0; i < normalized_shape.size(); i++) {
layer_norm_parameter->normalized_shape_[i] = normalized_shape[i];
}
layer_norm_parameter->epsilon_ = param->GetEpsilon();
layer_norm_parameter->elementwise_affine_ = param->GetElementwiseAffine();

return reinterpret_cast<OpParameter *>(layer_norm_parameter);
}

PopulateParameterRegistry::PopulateParameterRegistry() { PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter; populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter;
populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
@@ -1784,6 +1807,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_LshProjection] = PopulateLshProjectionParameter; populate_parameter_funcs_[schema::PrimitiveType_LshProjection] = PopulateLshProjectionParameter;
populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter; populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter;
populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter; populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter;
populate_parameter_funcs_[schema::PrimitiveType_LayerNorm] = PopulateLayerNormParameter;
} }


PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {


mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm.cc (+111, -0)

@@ -0,0 +1,111 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/layer_norm.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LayerNorm;

namespace mindspore::kernel {
int LayerNormCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int LayerNormCPUKernel::ReSize() {
auto shape = in_tensors_.front()->shape();
outer_size_ = 1;
inner_size_ = 1;
for (size_t i = 0; i < shape.size(); ++i) {
if (i + param_->normalized_dims_ < shape.size()) {
outer_size_ *= shape[i];
} else {
inner_size_ *= shape[i];
}
}
return RET_OK;
}

int LayerNormCPUKernel::DoLayerNorm(int thread_id) {
int ret = LayerNorm(outer_size_, inner_size_, src_data_, gamma_data_, beta_data_, param_->elementwise_affine_,
param_->epsilon_, dst_data_, thread_id, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoLayerNorm error error_code[" << ret << "]";
return ret;
}
return RET_OK;
}

int LayerNormRun(void *cdata, int task_id) {
auto LayerNormData = reinterpret_cast<LayerNormCPUKernel *>(cdata);
auto ret = LayerNormData->DoLayerNorm(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "LayerNormRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}

int LayerNormCPUKernel::Run() {
src_data_ = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
if (param_->elementwise_affine_) {
gamma_data_ = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
beta_data_ = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
}
dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
auto ret = ParallelLaunch(this->context_->thread_pool_, LayerNormRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]";
return ret;
}
return RET_OK;
}

kernel::LiteKernel *CpuLayerNormFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_LayerNorm. ";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_LayerNorm);
auto *kernel = new (std::nothrow) LayerNormCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new LayerNormCPUKernel fail!";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LayerNorm, CpuLayerNormFp32KernelCreator)
} // namespace mindspore::kernel

mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm.h (+52, -0)

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "include/context.h"
#include "nnacl/fp32/layer_norm.h"

using mindspore::lite::InnerContext;

namespace mindspore::kernel {
class LayerNormCPUKernel : public LiteKernel {
public:
LayerNormCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<LayerNormParameter *>(parameter);
}
~LayerNormCPUKernel() override{};

int Init() override;
int ReSize() override;
int Run() override;
int DoLayerNorm(int thread_id);

private:
LayerNormParameter *param_;
int outer_size_;
int inner_size_;
float *src_data_ = nullptr;
float *dst_data_ = nullptr;
float *gamma_data_ = nullptr;
float *beta_data_ = nullptr;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_H_

mindspore/lite/test/CMakeLists.txt (+1, -0)

@@ -178,6 +178,7 @@ if(ENABLE_CONVERTER)
        ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/layer_norm_fusion.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc


mindspore/lite/tools/converter/CMakeLists.txt (+1, -0)

@@ -39,6 +39,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        ../optimizer/fusion/conv_bn_fusion.cc
        ../optimizer/fusion/constant_folding_fusion.cc
        ../optimizer/fusion/quant_dtype_cast_fusion.cc
        ../optimizer/fusion/layer_norm_fusion.cc
        ../optimizer/graph/weight_format_transform_pass.cc
        ../optimizer/graph/weight_format_hardcode_pass.cc
        ../optimizer/graph/clip_convert_activation_pass.cc


mindspore/lite/tools/converter/anf_transform.cc (+2, -0)

@@ -25,6 +25,7 @@
#include "tools/optimizer/fusion/conv_bn_fusion.h" #include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h" #include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/optimizer/fusion/layer_norm_fusion.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h" #include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h" #include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h" #include "tools/optimizer/graph/clip_convert_activation_pass.h"
@@ -57,6 +58,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>()); pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
pm->AddPass(std::make_shared<opt::ConvScaleFusion>()); pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
pm->AddPass(std::make_shared<opt::LayerNormFusion>());
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation, pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation,
schema::ActivationType_RELU)); schema::ActivationType_RELU));
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation, pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,


mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc (+332, -0)

@@ -0,0 +1,332 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/layer_norm_fusion.h"
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "src/ops/add.h"
#include "src/ops/mul.h"
#include "src/ops/rsqrt.h"
#include "src/ops/reduce.h"
#include "src/ops/sub.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kAddInputsLength = 3;
constexpr size_t kSubInputsLength = 3;
constexpr size_t kMulInputsLength = 3;
constexpr size_t kRsqrtInputsLength = 2;
constexpr size_t kReduceInputsLength = 2;

bool IsAddNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Add;
}
return false;
}

bool IsSquaredDifferenceNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_SquaredDifference;
}
return false;
}

bool IsReduceNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Reduce;
}
return false;
}

bool IsRsqrtNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Rsqrt;
}
return false;
}

bool IsMulNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Mul;
}
return false;
}

bool IsSubNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Sub;
}
return false;
}
} // namespace

const BaseRef LayerNormFusion::DefinePattern() const {
auto mean1 = std::make_shared<CondVar>(IsReduceNode);
VectorRef mean1_ref = VectorRef({mean1, input_});
auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode);
VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref});
auto mul1 = std::make_shared<CondVar>(IsMulNode);
auto mean2 = std::make_shared<CondVar>(IsReduceNode);
VectorRef mean2_ref = VectorRef({mean2, squared_diffference1_ref});
auto add1 = std::make_shared<CondVar>(IsAddNode);
VectorRef add1_ref = VectorRef({add1, mean2_ref, epsilon_});
auto rsqrt1 = std::make_shared<CondVar>(IsRsqrtNode);
VectorRef rsqrt1_ref = VectorRef({rsqrt1, add1_ref});
auto mul2 = std::make_shared<CondVar>(IsMulNode);
VectorRef mul2_ref = VectorRef({mul2, rsqrt1_ref, gamma_});
VectorRef mul1_ref = VectorRef({mul1, input_, mul2_ref});
auto mul3 = std::make_shared<CondVar>(IsMulNode);
VectorRef mul3_ref = VectorRef({mul3, mean1_ref, mul2_ref});
auto sub1 = std::make_shared<CondVar>(IsSubNode);
VectorRef sub1_ref = VectorRef({sub1, beta_, mul3_ref});
auto add2 = std::make_shared<CondVar>(IsAddNode);
VectorRef add2_ref = VectorRef({add2, mul1_ref, sub1_ref});
return add2_ref;
}

CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const std::vector<int> &shape, const float epsilon) const {
MS_EXCEPTION_IF_NULL(func_graph);
auto layer_norm_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::LayerNormT> attr = std::make_unique<schema::LayerNormT>();
attr->normalizedShape = shape;
attr->epsilon = epsilon;
attr->elementwiseAffine = true;
layer_norm_primitive->value.type = schema::PrimitiveType_LayerNorm;
layer_norm_primitive->value.value = attr.release();
auto layer_norm_cvalue = lite::PrimitiveC::Create(layer_norm_primitive.release());
auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(layer_norm_cvalue));
std::vector<AnfNodePtr> new_node_inputs = {value_node};
auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
MS_EXCEPTION_IF_NULL(input_node);
new_node_inputs.push_back(input_node);
auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
MS_EXCEPTION_IF_NULL(gamma_node);
new_node_inputs.push_back(gamma_node);
auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
MS_EXCEPTION_IF_NULL(beta_node);
new_node_inputs.push_back(beta_node);
auto new_node = func_graph->NewCNode(new_node_inputs);
return new_node;
}

const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_LOG(DEBUG) << "layer_norm pass";
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}

// add2
auto add2_cnode = node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(add2_cnode) != lite::RET_OK || CheckInputSize(add2_cnode, kAddInputsLength) != lite::RET_OK) {
return nullptr;
}
auto add2_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(add2_cnode->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Add>>(add2_primitivec));
auto add2_op = utils::cast<std::shared_ptr<mindspore::lite::Add>>(add2_primitivec);
MS_ASSERT(add2_op != nullptr);
AnfNodePtr sub1_node = add2_cnode->input(2);
if (CheckIfAnfNodeIsNull(sub1_node) != lite::RET_OK) {
return nullptr;
}

// sub1
auto sub1_cnode = sub1_node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(sub1_cnode) != lite::RET_OK || CheckInputSize(sub1_cnode, kSubInputsLength) != lite::RET_OK) {
return nullptr;
}
auto sub1_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(sub1_cnode->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Sub>>(sub1_primitivec));
auto sub1_op = utils::cast<std::shared_ptr<mindspore::lite::Sub>>(sub1_primitivec);
MS_ASSERT(sub1_op != nullptr);
AnfNodePtr beta_node = sub1_cnode->input(1);
AnfNodePtr mul3_node = sub1_cnode->input(2);
if (CheckIfAnfNodeIsNull(beta_node) != lite::RET_OK || CheckIfAnfNodeIsNull(mul3_node) != lite::RET_OK) {
return nullptr;
}

// beta
if (CheckIfNodeIsParam(beta_node) != lite::RET_OK) {
return nullptr;
}
auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
auto beta_tensor = std::dynamic_pointer_cast<ParamValueLite>(beta_param);
auto beta_shape = beta_tensor->tensor_shape();

// mul3
auto mul3_cnode = mul3_node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(mul3_cnode) != lite::RET_OK || CheckInputSize(mul3_cnode, kMulInputsLength) != lite::RET_OK) {
return nullptr;
}
auto mul3_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mul3_cnode->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Mul>>(mul3_primitivec));
auto mul3_op = utils::cast<std::shared_ptr<mindspore::lite::Mul>>(mul3_primitivec);
MS_ASSERT(mul3_op != nullptr);
AnfNodePtr mean1_node = mul3_cnode->input(1);
AnfNodePtr mul2_node = mul3_cnode->input(2);
if (CheckIfAnfNodeIsNull(mean1_node) != lite::RET_OK || CheckIfAnfNodeIsNull(mul2_node) != lite::RET_OK) {
return nullptr;
}

// mul2
auto mul2_cnode = mul2_node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(mul2_cnode) != lite::RET_OK || CheckInputSize(mul2_cnode, kMulInputsLength) != lite::RET_OK) {
return nullptr;
}
auto mul2_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mul2_cnode->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Mul>>(mul2_primitivec));
auto mul2_op = utils::cast<std::shared_ptr<mindspore::lite::Mul>>(mul2_primitivec);
MS_ASSERT(mul2_op != nullptr);
AnfNodePtr rsqrt_node = mul2_cnode->input(1);
AnfNodePtr gamma_node = mul2_cnode->input(2);
if (CheckIfAnfNodeIsNull(rsqrt_node) != lite::RET_OK || CheckIfAnfNodeIsNull(gamma_node) != lite::RET_OK) {
return nullptr;
}

// gamma
if (CheckIfNodeIsParam(gamma_node) != lite::RET_OK) {
return nullptr;
}
auto gamma_param = gamma_node->cast<ParameterPtr>()->default_param();
auto gamma_tensor = std::dynamic_pointer_cast<ParamValueLite>(gamma_param);
auto gamma_shape = gamma_tensor->tensor_shape();

// rsqrt
auto rsqrt_cnode = rsqrt_node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(rsqrt_cnode) != lite::RET_OK ||
CheckInputSize(rsqrt_cnode, kRsqrtInputsLength) != lite::RET_OK) {
return nullptr;
}
auto rsqrt_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(rsqrt_cnode->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Rsqrt>>(rsqrt_primitivec));
auto rsqrt_op = utils::cast<std::shared_ptr<mindspore::lite::Rsqrt>>(rsqrt_primitivec);
MS_ASSERT(rsqrt_op != nullptr);
AnfNodePtr add1_node = rsqrt_cnode->input(1);
if (CheckIfAnfNodeIsNull(add1_node) != lite::RET_OK) {
return nullptr;
}

// add1
auto add1_cnode = add1_node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(add1_cnode) != lite::RET_OK || CheckInputSize(add1_cnode, kAddInputsLength) != lite::RET_OK) {
return nullptr;
}
auto add1_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(add1_cnode->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Add>>(add1_primitivec));
auto add1_op = utils::cast<std::shared_ptr<mindspore::lite::Add>>(add1_primitivec);
MS_ASSERT(add1_op != nullptr);
AnfNodePtr mean2_node = add1_cnode->input(1);
AnfNodePtr epsilon_node = add1_cnode->input(2);
if (CheckIfAnfNodeIsNull(mean2_node) != lite::RET_OK || CheckIfAnfNodeIsNull(epsilon_node) != lite::RET_OK) {
return nullptr;
}

// epsilon
if (CheckIfNodeIsParam(epsilon_node) != lite::RET_OK) {
// delete[] add_bias_data;
return nullptr;
}
auto epsilon_param = epsilon_node->cast<ParameterPtr>()->default_param();
auto epsilon_tensor = std::dynamic_pointer_cast<ParamValueLite>(epsilon_param);
auto epsilon_data = reinterpret_cast<float *>(epsilon_tensor->tensor_addr());
auto epsilon_shape = epsilon_tensor->tensor_shape();

// mean2
auto mean2_cnode = mean2_node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(mean2_cnode) != lite::RET_OK ||
CheckInputSize(mean2_cnode, kReduceInputsLength) != lite::RET_OK) {
return nullptr;
}
auto mean2_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mean2_cnode->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Reduce>>(mean2_primitivec));
auto mean2_op = utils::cast<std::shared_ptr<mindspore::lite::Reduce>>(mean2_primitivec);
MS_ASSERT(mean2_op != nullptr);
if (mean2_op->GetMode() != schema::ReduceMode_ReduceMean) {
return nullptr;
}
auto mean2_axes = mean2_op->GetAxes();
AnfNodePtr squared_difference_node = mean2_cnode->input(1);
if (CheckIfAnfNodeIsNull(squared_difference_node) != lite::RET_OK) {
return nullptr;
}

// mean1
auto mean1_cnode = mean1_node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(mean1_cnode) != lite::RET_OK ||
CheckInputSize(mean1_cnode, kReduceInputsLength) != lite::RET_OK) {
return nullptr;
}
auto mean1_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mean1_cnode->input(0));
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Reduce>>(mean1_primitivec));
auto mean1_op = utils::cast<std::shared_ptr<mindspore::lite::Reduce>>(mean1_primitivec);
MS_ASSERT(mean1_op != nullptr);
if (mean1_op->GetMode() != schema::ReduceMode_ReduceMean) {
return nullptr;
}
AnfNodePtr input3_node = mean1_cnode->input(1);
auto mean1_axes = mean1_op->GetAxes();
if (CheckIfAnfNodeIsNull(input3_node) != lite::RET_OK) {
return nullptr;
}

// verify two mean ops have same axes
if (mean1_axes.size() != mean2_axes.size()) {
return nullptr;
}
for (size_t i = 0; i < mean1_axes.size(); ++i) {
if (mean1_axes[i] != mean2_axes[i]) {
return nullptr;
}
}
// verify axes size and gamma/beta size are equal
if (mean1_axes.size() != gamma_shape.size() || mean1_axes.size() != beta_shape.size()) {
return nullptr;
}
// verify gamma and beta have same shape
for (size_t i = 0; i < gamma_shape.size(); ++i) {
if (gamma_shape[i] != beta_shape[i]) {
return nullptr;
}
}
// verify epsilon has exactly one element
float epsilon;
if (epsilon_shape.empty() || (epsilon_shape.size() == 1 && epsilon_shape[0] == 1)) {
epsilon = epsilon_data[0];
} else {
return nullptr;
}

auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, gamma_shape, epsilon);
layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success";
return layer_norm_cnode;
}
} // namespace opt
} // namespace mindspore
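For reference, the subgraph that DefinePattern matches appears to be the standard decomposed layer norm that converters commonly see in transformer graphs. Writing mean1 as mu = ReduceMean(x), mean2 as sigma^2 = ReduceMean((x - mu)^2), and mul2 as s = gamma * rsqrt(sigma^2 + epsilon), the matched root add2 evaluates to

\[ x \cdot s + (\beta - \mu \cdot s) \;=\; \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta, \]

which is what the replacement LayerNorm node with elementwiseAffine = true computes, with the gamma_ and beta_ parameters captured from mul2 and sub1 respectively.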

mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h (+55, -0)

@@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_

#include <vector>
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {

class LayerNormFusion : public PatternProcessPass {
public:
explicit LayerNormFusion(const std::string &name = "layer_norm_fusion", bool multigraph = true)
: PatternProcessPass(name, multigraph) {
input_ = std::make_shared<Var>();
gamma_ = std::make_shared<Var>();
beta_ = std::make_shared<Var>();
epsilon_ = std::make_shared<Var>();
}

~LayerNormFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
CNodePtr CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const std::vector<int> &shape,
const float epsilon) const;
VarPtr input_;
VarPtr gamma_;
VarPtr beta_;
VarPtr epsilon_;
};

} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
