Merge pull request !6940 from zhaozhenlong/lite/op/scale_relu (tags/v1.1.0)
| @@ -77,3 +77,147 @@ void DoScale(float *in_data, float *out_data, float *scale, float *offset, int t | |||
| scale_param->inner_size_); | |||
| } | |||
| } | |||
// Scale followed by ReLU for data laid out as [outer, axis_size, inner_size]:
//   out = max(in * scale[i] + offset[i], 0)
// where i is the position along the axis dimension; scale[i] / offset[i] are shared by all
// inner_size elements at that position. Only outer slices in [outer_start, outer_end) are
// read and written.
void ScaleInnerRelu(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end,
                    int axis_size, int inner_size) {
#ifdef ENABLE_ARM64
  const float32x4_t zeros = vdupq_n_f32(0.0f);
#endif
  for (int out = outer_start; out < outer_end; out++) {
    int out_offset = out * axis_size * inner_size;
    for (int i = 0; i < axis_size; i++) {
      int axis_offset = out_offset + i * inner_size;
      int in_index = 0;
#ifdef ENABLE_ARM64
      // scale[i] and offset[i] do not change inside the inner run, so broadcast them once.
      const float32x4_t scale_4 = vdupq_n_f32(scale[i]);
      const float32x4_t offset_4 = vdupq_n_f32(offset[i]);
      // "<=" so that the last complete group of four elements also takes the vector path.
      for (; in_index <= inner_size - 4; in_index += 4) {
        int in_offset = axis_offset + in_index;
        float32x4_t data = vld1q_f32(in_data + in_offset);
        float32x4_t tmp = vfmaq_f32(offset_4, data, scale_4);
        float32x4_t result = vmaxq_f32(tmp, zeros);
        vst1q_f32(out_data + in_offset, result);
      }
#endif
      // Scalar tail (and the whole run when NEON is not enabled).
      for (; in_index < inner_size; in_index++) {
        int in_offset = axis_offset + in_index;
        float tmp = in_data[in_offset] * scale[i] + offset[i];
        out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f;
      }
    }
  }
}
// Scale followed by ReLU for data laid out as [outer, axis_size] (inner size of 1):
//   out[o][k] = max(in[o][k] * scale[k] + offset[k], 0)
// Only outer rows in [outer_start, outer_end) are read and written.
void ScaleAxisRelu(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end,
                   int axis_size) {
#ifdef ENABLE_ARM64
  const float32x4_t zeros = vdupq_n_f32(0.0f);
#endif
  for (int out = outer_start; out < outer_end; out++) {
    int out_offset = out * axis_size;
    int index = 0;
#ifdef ENABLE_ARM64
    // "<=" so that the last complete group of four elements also takes the vector path.
    for (; index <= axis_size - 4; index += 4) {
      int in_offset = out_offset + index;
      float32x4_t data = vld1q_f32(in_data + in_offset);
      float32x4_t scale_4 = vld1q_f32(scale + index);
      float32x4_t offset_4 = vld1q_f32(offset + index);
      float32x4_t tmp = vfmaq_f32(offset_4, data, scale_4);
      float32x4_t result = vmaxq_f32(tmp, zeros);
      vst1q_f32(out_data + in_offset, result);
    }
#endif
    // Scalar tail (and the whole row when NEON is not enabled).
    for (; index < axis_size; index++) {
      int in_offset = out_offset + index;
      float tmp = in_data[in_offset] * scale[index] + offset[index];
      out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f;
    }
  }
}
| void DoScaleRelu(float *in_data, float *out_data, float *scale, float *offset, int task_id, | |||
| ScaleParameter *scale_param) { | |||
| int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_); | |||
| int outer_start = task_id * outer_step; | |||
| int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); | |||
| if (scale_param->inner_size_ == 1) { | |||
| ScaleAxisRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); | |||
| } else { | |||
| ScaleInnerRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, | |||
| scale_param->inner_size_); | |||
| } | |||
| } | |||
| void ScaleInnerRelu6(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end, | |||
| int axis_size, int inner_size) { | |||
| #ifdef ENABLE_ARM64 | |||
| float32x4_t zeros = {0, 0, 0, 0}; | |||
| float32x4_t bounds = {6, 6, 6, 6}; | |||
| #endif | |||
| for (int out = outer_start; out < outer_end; out++) { | |||
| int out_offset = out * axis_size * inner_size; | |||
| for (int i = 0; i < axis_size; i++) { | |||
| int axis_offset = out_offset + i * inner_size; | |||
| int in_index = 0; | |||
| #ifdef ENABLE_ARM64 | |||
| for (; in_index < inner_size - 4; in_index += 4) { | |||
| int in_offset = axis_offset + in_index; | |||
| float32x4_t data = vld1q_f32(in_data + in_offset); | |||
| float32x4_t scale_4 = vdupq_n_f32(scale[i]); | |||
| float32x4_t offset_4 = vdupq_n_f32(offset[i]); | |||
| float32x4_t tmp = vfmaq_f32(offset_4, data, scale_4); | |||
| float32x4_t result = vminq_f32(vmaxq_f32(tmp, zeros), bounds); | |||
| vst1q_f32(out_data + in_offset, result); | |||
| } | |||
| #endif | |||
| for (; in_index < inner_size; in_index++) { | |||
| int in_offset = axis_offset + in_index; | |||
| float tmp = in_data[in_offset] * scale[i] + offset[i]; | |||
| out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void ScaleAxisRelu6(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end, | |||
| int axis_size) { | |||
| #ifdef ENABLE_ARM64 | |||
| float32x4_t zeros = {0, 0, 0, 0}; | |||
| float32x4_t bounds = {6, 6, 6, 6}; | |||
| #endif | |||
| for (int out = outer_start; out < outer_end; out++) { | |||
| int out_offset = out * axis_size; | |||
| int index = 0; | |||
| #ifdef ENABLE_ARM64 | |||
| for (; index < axis_size - 4; index += 4) { | |||
| int in_offset = out_offset + index; | |||
| float32x4_t data = vld1q_f32(in_data + in_offset); | |||
| float32x4_t scale_4 = vld1q_f32(scale + index); | |||
| float32x4_t offset_4 = vld1q_f32(offset + index); | |||
| float32x4_t tmp = vfmaq_f32(offset_4, data, scale_4); | |||
| float32x4_t result = vminq_f32(vmaxq_f32(tmp, zeros), bounds); | |||
| vst1q_f32(out_data + in_offset, result); | |||
| } | |||
| #endif | |||
| for (; index < axis_size; index++) { | |||
| int in_offset = out_offset + index; | |||
| float tmp = in_data[in_offset] * scale[index] + offset[index]; | |||
| out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); | |||
| } | |||
| } | |||
| } | |||
| void DoScaleRelu6(float *in_data, float *out_data, float *scale, float *offset, int task_id, | |||
| ScaleParameter *scale_param) { | |||
| int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_); | |||
| int outer_start = task_id * outer_step; | |||
| int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); | |||
| if (scale_param->inner_size_ == 1) { | |||
| ScaleAxisRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); | |||
| } else { | |||
| ScaleInnerRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, | |||
| scale_param->inner_size_); | |||
| } | |||
| } | |||
| @@ -23,6 +23,10 @@ | |||
| extern "C" { | |||
| #endif | |||
| void DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param); | |||
| void DoScaleRelu(float *in_data, float *out_data, float *scale, float *offset, int task_id, | |||
| ScaleParameter *scale_param); | |||
| void DoScaleRelu6(float *in_data, float *out_data, float *scale, float *offset, int task_id, | |||
| ScaleParameter *scale_param); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -33,6 +33,7 @@ typedef struct ScaleParameter { | |||
| int scale_zp_; | |||
| int offset_zp_; | |||
| int output_zp_; | |||
| int activation_type_; | |||
| } ScaleParameter; | |||
| #endif // MINDSPORE_LITE_NNACL_SCALE_H_ | |||
| @@ -416,6 +416,7 @@ table BNGradInput { | |||
| } | |||
| table Scale { | |||
| axis: int; | |||
| activationType: ActivationType = 0; | |||
| } | |||
| table Eltwise { | |||
| @@ -20,12 +20,16 @@ namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Scale::GetAxis() const { return this->primitive_->value.AsScale()->axis; } | |||
| void Scale::SetAxis(int axis) { this->primitive_->value.AsScale()->axis = axis; } | |||
| int Scale::GetActivationType() const { return this->primitive_->value.AsScale()->activationType; } | |||
| void Scale::SetActivationType(int activation_type) { | |||
| this->primitive_->value.AsScale()->activationType = (schema::ActivationType)activation_type; | |||
| } | |||
| #else | |||
| int Scale::GetAxis() const { return this->primitive_->value_as_Scale()->axis(); } | |||
| int Scale::GetActivationType() const { return this->primitive_->value_as_Scale()->activationType(); } | |||
| int Scale::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| @@ -34,7 +38,7 @@ int Scale::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: | |||
| MS_LOG(ERROR) << "value_as_Scale return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateScale(*fbb, attr->axis()); | |||
| auto val_offset = schema::CreateScale(*fbb, attr->axis(), attr->activationType()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Scale, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| @@ -32,6 +32,7 @@ class Scale : public PrimitiveC { | |||
| Scale() = default; | |||
| explicit Scale(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetAxis(int axis); | |||
| void SetActivationType(int activation_type); | |||
| #else | |||
| Scale() = default; | |||
| @@ -39,6 +40,7 @@ class Scale : public PrimitiveC { | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int GetAxis() const; | |||
| int GetActivationType() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -943,6 +943,7 @@ OpParameter *PopulateScaleParameter(const mindspore::lite::PrimitiveC *primitive | |||
| scale_param->op_parameter_.type_ = primitive->Type(); | |||
| auto param = reinterpret_cast<mindspore::lite::Scale *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| scale_param->axis_ = param->GetAxis(); | |||
| scale_param->activation_type_ = param->GetActivationType(); | |||
| return reinterpret_cast<OpParameter *>(scale_param); | |||
| } | |||
| @@ -35,52 +35,56 @@ ScaleCPUKernel::~ScaleCPUKernel() { | |||
| scale_ = nullptr; | |||
| } | |||
| } | |||
| if (offset_ != nullptr) { | |||
| free(offset_); | |||
| offset_ = nullptr; | |||
| if (scale_param_->const_offset_) { | |||
| if (offset_ != nullptr) { | |||
| free(offset_); | |||
| offset_ = nullptr; | |||
| } | |||
| } | |||
| } | |||
| int ScaleCPUKernel::InitScaleOffset() { | |||
| auto scale_tensor = in_tensors_.at(1); | |||
| float *scale_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->data_c()); | |||
| if (scale_ptr != nullptr) { | |||
| if (reinterpret_cast<float *>(scale_tensor->data_c()) != nullptr) { | |||
| scale_param_->const_scale_ = true; | |||
| if (scale_ != nullptr) { | |||
| free(scale_); | |||
| scale_ = nullptr; | |||
| } | |||
| scale_ = reinterpret_cast<float *>(malloc(scale_tensor->ElementsNum() * sizeof(float))); | |||
| if (scale_ == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc buffer failed."; | |||
| return RET_ERROR; | |||
| } | |||
| memcpy(scale_, scale_ptr, scale_tensor->ElementsNum() * sizeof(float)); | |||
| memcpy(scale_, scale_tensor->data_c(), scale_tensor->ElementsNum() * sizeof(float)); | |||
| } else { | |||
| scale_param_->const_scale_ = false; | |||
| scale_ = nullptr; | |||
| } | |||
| if (offset_ != nullptr) { | |||
| free(offset_); | |||
| offset_ = nullptr; | |||
| } | |||
| offset_ = reinterpret_cast<float *>(malloc(scale_param_->axis_size_ * sizeof(float))); | |||
| if (offset_ == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc buffer failed."; | |||
| return RET_ERROR; | |||
| } | |||
| memset(offset_, 0, scale_param_->axis_size_ * sizeof(float)); | |||
| if (in_tensors_.size() == 3) { | |||
| if (in_tensors_.size() == 2) { | |||
| scale_param_->const_offset_ = true; | |||
| offset_ = reinterpret_cast<float *>(malloc(scale_tensor->ElementsNum() * sizeof(float))); | |||
| if (offset_ == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| memset(offset_, 0, scale_tensor->ElementsNum() * sizeof(float)); | |||
| } else if (in_tensors_.size() == 3 && reinterpret_cast<float *>(in_tensors_.at(2)->data_c()) != nullptr) { | |||
| scale_param_->const_offset_ = true; | |||
| auto offset_tensor = in_tensors_.at(2); | |||
| if (offset_tensor->data_c() != nullptr) { | |||
| memcpy(offset_, offset_tensor->data_c(), offset_tensor->ElementsNum() * sizeof(float)); | |||
| MS_ASSERT(scale_tensor->ElementsNum() == offset_tensor->ElementsNum()); | |||
| offset_ = reinterpret_cast<float *>(malloc(offset_tensor->ElementsNum() * sizeof(float))); | |||
| if (offset_ == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| memcpy(offset_, offset_tensor->data_c(), offset_tensor->ElementsNum() * sizeof(float)); | |||
| } else { | |||
| scale_param_->const_offset_ = false; | |||
| offset_ = nullptr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ScaleCPUKernel::InitParameter() { | |||
| int ScaleCPUKernel::CalculateParameter() { | |||
| auto in_tensor = in_tensors_.at(0); | |||
| auto in_shape = in_tensor->shape(); | |||
| auto scale_tensor = in_tensors_.at(1); | |||
| @@ -118,32 +122,44 @@ int ScaleCPUKernel::Init() { | |||
| MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << in_tensors_.size() << " is given."; | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = InitScaleOffset(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed."; | |||
| return RET_ERROR; | |||
| } | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| ReSize(); | |||
| return RET_OK; | |||
| } | |||
| int ScaleCPUKernel::ReSize() { | |||
| auto ret = InitParameter(); | |||
| auto ret = CalculateParameter(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Scale fp32 InitParameter failed."; | |||
| return RET_ERROR; | |||
| } | |||
| ret = InitScaleOffset(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ScaleCPUKernel::Scale(int task_id) { | |||
| DoScale(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_); | |||
| switch (scale_param_->activation_type_) { | |||
| case schema::ActivationType_RELU6: | |||
| DoScaleRelu6(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_); | |||
| break; | |||
| case schema::ActivationType_RELU: | |||
| DoScaleRelu(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_); | |||
| break; | |||
| case schema::ActivationType_NO_ACTIVATION: | |||
| DoScale(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -164,14 +180,15 @@ int ScaleCPUKernel::Run() { | |||
| return ret; | |||
| } | |||
| auto in_tensor = in_tensors_.front(); | |||
| input_ptr_ = reinterpret_cast<float *>(in_tensor->MutableData()); | |||
| if (scale_ == nullptr) { | |||
| input_ptr_ = reinterpret_cast<float *>(in_tensor->data_c()); | |||
| if (!scale_param_->const_scale_) { | |||
| auto scale_tensor = in_tensors_[1]; | |||
| scale_ = reinterpret_cast<float *>(scale_tensor->MutableData()); | |||
| scale_ = reinterpret_cast<float *>(scale_tensor->data_c()); | |||
| } | |||
| if (offset_ == nullptr) { | |||
| if (!scale_param_->const_offset_) { | |||
| MS_ASSERT(in_tensors_.size() == 3); | |||
| auto offset_tensor = in_tensors_.at(2); | |||
| memcpy(offset_, offset_tensor->data_c(), offset_tensor->ElementsNum() * sizeof(float)); | |||
| offset_ = reinterpret_cast<float *>(offset_tensor->data_c()); | |||
| } | |||
| auto out_tensor = out_tensors_.front(); | |||
| output_ptr_ = reinterpret_cast<float *>(out_tensor->MutableData()); | |||
| @@ -36,7 +36,7 @@ class ScaleCPUKernel : public LiteKernel { | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int InitParameter(); | |||
| int CalculateParameter(); | |||
| int InitScaleOffset(); | |||
| int Scale(int task_id); | |||
| @@ -0,0 +1,160 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include "mindspore/lite/src/lite_kernel.h" | |||
| #include "mindspore/lite/src/tensor.h" | |||
| #include "common/common_test.h" | |||
| #include "nnacl/pad_parameter.h" | |||
| #include "mindspore/lite/src/kernel_registry.h" | |||
| #include "mindspore/lite/schema/ops_generated.h" | |||
| #include "nnacl/fp32/scale.h" | |||
| using mindspore::schema::ActivationType; | |||
| using mindspore::schema::ActivationType_NO_ACTIVATION; | |||
| using mindspore::schema::ActivationType_RELU; | |||
| using mindspore::schema::ActivationType_RELU6; | |||
| using mindspore::schema::Format_NHWC; | |||
| namespace mindspore { | |||
| class TestScaleFp32 : public mindspore::CommonTest { | |||
| public: | |||
| TestScaleFp32() = default; | |||
| void Prepare(const std::vector<int> &input_shape, const std::vector<int> &scale_shape, | |||
| const std::vector<int> &offset_shape, const std::vector<int> &output_shape, float *input_data, | |||
| float *scale_data, float *offset_data, float *output_data, int axis, ActivationType act_type, | |||
| const int thread_num); | |||
| void TearDown() override; | |||
| public: | |||
| float err_tol = 1e-5; | |||
| lite::Tensor in_tensor_; | |||
| lite::Tensor scale_tensor_; | |||
| lite::Tensor offset_tensor_; | |||
| lite::Tensor out_tensor_; | |||
| ScaleParameter param_; | |||
| std::vector<lite::Tensor *> inputs_{&in_tensor_, &scale_tensor_, &offset_tensor_}; | |||
| std::vector<lite::Tensor *> outputs_{&out_tensor_}; | |||
| kernel::KernelKey desc_ = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Scale}; | |||
| lite::InnerContext ctx_ = lite::InnerContext(); | |||
| kernel::KernelCreator creator_ = nullptr; | |||
| kernel::LiteKernel *kernel_ = nullptr; | |||
| }; | |||
| void TestScaleFp32::TearDown() { | |||
| in_tensor_.SetData(nullptr); | |||
| scale_tensor_.SetData(nullptr); | |||
| offset_tensor_.SetData(nullptr); | |||
| out_tensor_.SetData(nullptr); | |||
| } | |||
| void TestScaleFp32::Prepare(const std::vector<int> &input_shape, const std::vector<int> &scale_shape, | |||
| const std::vector<int> &offset_shape, const std::vector<int> &output_shape, | |||
| float *input_data, float *scale_data, float *offset_data, float *output_data, int axis, | |||
| ActivationType act_type, const int thread_num) { | |||
| in_tensor_.set_data_type(kNumberTypeFloat32); | |||
| in_tensor_.SetFormat(Format_NHWC); | |||
| in_tensor_.set_shape(input_shape); | |||
| scale_tensor_.set_data_type(kNumberTypeFloat32); | |||
| scale_tensor_.SetFormat(Format_NHWC); | |||
| scale_tensor_.set_shape(scale_shape); | |||
| offset_tensor_.set_data_type(kNumberTypeFloat32); | |||
| offset_tensor_.SetFormat(Format_NHWC); | |||
| offset_tensor_.set_shape(offset_shape); | |||
| out_tensor_.set_data_type(kNumberTypeFloat32); | |||
| out_tensor_.set_shape(output_shape); | |||
| in_tensor_.SetData(input_data); | |||
| scale_tensor_.SetData(scale_data); | |||
| offset_tensor_.SetData(offset_data); | |||
| out_tensor_.SetData(output_data); | |||
| param_.activation_type_ = act_type; | |||
| param_.axis_ = axis; | |||
| ctx_ = lite::InnerContext(); | |||
| ctx_.thread_num_ = thread_num; | |||
| ctx_.Init(); | |||
| creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc_); | |||
| ASSERT_NE(creator_, nullptr); | |||
| kernel_ = creator_(inputs_, outputs_, reinterpret_cast<OpParameter *>(¶m_), &ctx_, desc_, nullptr); | |||
| ASSERT_NE(kernel_, nullptr); | |||
| } | |||
| TEST_F(TestScaleFp32, ScaleNoAct) { | |||
| std::vector<int> input_shape{1, 2, 2, 3}; | |||
| std::vector<int> scale_shape{3}; | |||
| std::vector<int> offset_shape{3}; | |||
| std::vector<int> output_shape{1, 2, 2, 3}; | |||
| float in_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0}; | |||
| float scale_data[3] = {1.0, 2.0, 3.0}; | |||
| float offset_data[3] = {1.0, 1.0, 1.0}; | |||
| float out_data[12] = {0}; | |||
| int axis = -1; | |||
| int thread_num = 2; | |||
| Prepare(input_shape, scale_shape, offset_shape, output_shape, in_data, scale_data, offset_data, out_data, axis, | |||
| ActivationType_NO_ACTIVATION, thread_num); | |||
| auto ret = kernel_->Run(); | |||
| EXPECT_EQ(0, ret); | |||
| std::vector<float> expect{1.0, 3.0, 7.0, 4.0, 9.0, 16.0, 7.0, 15.0, 25.0, 10.0, 21.0, 34.0}; | |||
| CompareOutputData(out_data, expect.data(), 12, err_tol); | |||
| } | |||
| TEST_F(TestScaleFp32, ScaleRelu) { | |||
| std::vector<int> input_shape{1, 2, 2, 3}; | |||
| std::vector<int> scale_shape{3}; | |||
| std::vector<int> offset_shape{3}; | |||
| std::vector<int> output_shape{1, 2, 2, 3}; | |||
| float in_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0}; | |||
| float scale_data[3] = {1.0, 2.0, 3.0}; | |||
| float offset_data[3] = {-5.0, -5.0, -5.0}; | |||
| float out_data[12] = {0}; | |||
| int axis = -1; | |||
| int thread_num = 2; | |||
| Prepare(input_shape, scale_shape, offset_shape, output_shape, in_data, scale_data, offset_data, out_data, axis, | |||
| ActivationType_RELU, thread_num); | |||
| auto ret = kernel_->Run(); | |||
| EXPECT_EQ(0, ret); | |||
| std::vector<float> expect{0.0, 0.0, 1.0, 0.0, 3.0, 10.0, 1.0, 9.0, 19.0, 4.0, 15.0, 28.0}; | |||
| CompareOutputData(out_data, expect.data(), 12, err_tol); | |||
| } | |||
| TEST_F(TestScaleFp32, ScaleRelu6) { | |||
| std::vector<int> input_shape{1, 2, 2, 3}; | |||
| std::vector<int> scale_shape{3}; | |||
| std::vector<int> offset_shape{3}; | |||
| std::vector<int> output_shape{1, 2, 2, 3}; | |||
| float in_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0}; | |||
| float scale_data[3] = {1.0, 2.0, 3.0}; | |||
| float offset_data[3] = {-5.0, -5.0, -5.0}; | |||
| float out_data[12] = {0}; | |||
| int axis = -1; | |||
| int thread_num = 2; | |||
| Prepare(input_shape, scale_shape, offset_shape, output_shape, in_data, scale_data, offset_data, out_data, axis, | |||
| ActivationType_RELU6, thread_num); | |||
| auto ret = kernel_->Run(); | |||
| EXPECT_EQ(0, ret); | |||
| std::vector<float> expect{0.0, 0.0, 1.0, 0.0, 3.0, 6.0, 1.0, 6.0, 6.0, 4.0, 6.0, 6.0}; | |||
| CompareOutputData(out_data, expect.data(), 12, err_tol); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -144,23 +144,26 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt | |||
| // NHWC | |||
| int shape_size = graph->allTensors.at(addBiasIndex)->dims.size(); | |||
| scaleParam->axis = 0 - shape_size; | |||
| mulNode->primitive->value.value = scaleParam.release(); | |||
| mulNode->inputIndex.push_back(addBiasIndex); | |||
| if (addNode->primitive->value.AsAdd()->activationType != ActivationType_NO_ACTIVATION) { | |||
| auto activationType = addNode->primitive->value.AsAdd()->activationType; | |||
| if (activationType == ActivationType_RELU || activationType == ActivationType_RELU6 || | |||
| activationType == ActivationType_NO_ACTIVATION) { | |||
| // delete addnode | |||
| scaleParam->activationType = activationType; | |||
| auto status = IsolateOneWayNode(graph, addNode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed"; | |||
| return status; | |||
| } | |||
| } else { | |||
| // repace addnode as activation | |||
| std::unique_ptr<ActivationT> activationParam(new ActivationT()); | |||
| activationParam->type = addNode->primitive->value.AsAdd()->activationType; | |||
| addNode->primitive->value.type = schema::PrimitiveType_Activation; | |||
| addNode->primitive->value.value = activationParam.release(); | |||
| addNode->inputIndex.pop_back(); | |||
| return RET_OK; | |||
| } | |||
| // delete addnode | |||
| auto status = IsolateOneWayNode(graph, addNode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed"; | |||
| return status; | |||
| } | |||
| mulNode->primitive->value.value = scaleParam.release(); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||