| @@ -174,6 +174,19 @@ union PrimitiveType { | |||||
| Where, | Where, | ||||
| OneHot, | OneHot, | ||||
| Lstm, | Lstm, | ||||
| Conv2DGradFilter, | |||||
| Conv2DGradInput, | |||||
| PoolingGrad, | |||||
| BNGradInput, | |||||
| OptMomentum, | |||||
| BiasGrad, | |||||
| SoftmaxCrossEntropy, | |||||
| AddGrad, | |||||
| SubGrad, | |||||
| MulGrad, | |||||
| DivGrad, | |||||
| PowerGrad, | |||||
| ActivationGrad, | |||||
| PriorBox | PriorBox | ||||
| } | } | ||||
| @@ -55,7 +55,25 @@ enum ActivationType : byte { | |||||
| LINEAR = 15, | LINEAR = 15, | ||||
| UNKNOW = 16 | UNKNOW = 16 | ||||
| } | } | ||||
| enum ActivationGradType : byte { | |||||
| NO_ACTIVATION = 0, | |||||
| RELU = 1, | |||||
| SIGMOID = 2, | |||||
| RELU6 = 3, | |||||
| ELU = 4, | |||||
| LEAKY_RELU = 5, | |||||
| ABS = 6, | |||||
| RELU1 = 7, | |||||
| SOFTSIGN = 8, | |||||
| SOFTPLUS = 9, | |||||
| TANH = 10, | |||||
| SELU = 11, | |||||
| HSWISH = 12, | |||||
| HSIGMOID = 13, | |||||
| THRESHOLDRELU = 14, | |||||
| LINEAR = 15, | |||||
| UNKNOW = 16 | |||||
| } | |||||
| enum ReduceType : byte { | enum ReduceType : byte { | ||||
| REDUCE_MAX = 0, | REDUCE_MAX = 0, | ||||
| REDUCE_MEAN = 1, | REDUCE_MEAN = 1, | ||||
| @@ -125,6 +143,10 @@ table SoftMax { | |||||
| table Activation { | table Activation { | ||||
| type: ActivationType = 0; | type: ActivationType = 0; | ||||
| } | } | ||||
| table ActivationGrad { | |||||
| type: ActivationGradType = 0; | |||||
| } | |||||
| table Conv2D { | table Conv2D { | ||||
| format: Format = 0; | format: Format = 0; | ||||
| @@ -146,7 +168,45 @@ table Conv2D { | |||||
| activationType: ActivationType = 0; | activationType: ActivationType = 0; | ||||
| } | } | ||||
| table FusedBatchNorm { | |||||
| table Conv2DGradFilter { | |||||
| format: Format = 0; | |||||
| group: int; | |||||
| channelIn: int; | |||||
| channelOut: int; | |||||
| kernelW: int; | |||||
| kernelH: int; | |||||
| strideW: int; | |||||
| strideH: int; | |||||
| padMode: PadMode; | |||||
| padUp: int; | |||||
| padDown: int; | |||||
| padLeft: int; | |||||
| padRight: int; | |||||
| dilateW: int; | |||||
| dilateH: int; | |||||
| hasBias: bool = false; | |||||
| activationType: ActivationType = 0; | |||||
| } | |||||
| table Conv2DGradInput { | |||||
| format: Format = 0; | |||||
| group: int; | |||||
| channelIn: int; | |||||
| channelOut: int; | |||||
| kernelW: int; | |||||
| kernelH: int; | |||||
| strideW: int; | |||||
| strideH: int; | |||||
| padMode: PadMode; | |||||
| padUp: int; | |||||
| padDown: int; | |||||
| padLeft: int; | |||||
| padRight: int; | |||||
| dilateW: int; | |||||
| dilateH: int; | |||||
| hasBias: bool = false; | |||||
| activationType: ActivationType = 0; | |||||
| }table FusedBatchNorm { | |||||
| epsilon: float = 0.00001; // eg. epsilon=0.001 | epsilon: float = 0.00001; // eg. epsilon=0.001 | ||||
| momentum: float = 0.9; | momentum: float = 0.9; | ||||
| spatial: int = 1; | spatial: int = 1; | ||||
| @@ -156,6 +216,31 @@ table CaffeBatchNorm { | |||||
| epsilon: float; // eg. epsilon=0.001 | epsilon: float; // eg. epsilon=0.001 | ||||
| } | } | ||||
| table BiasGrad { | |||||
| axis: [int]; | |||||
| } | |||||
| table SoftmaxCrossEntropy { | |||||
| axis: [int]; | |||||
| } | |||||
| table PoolingGrad { | |||||
| format: Format = 0; | |||||
| poolingMode: PoolMode; | |||||
| global: bool = false; | |||||
| windowW: int; | |||||
| windowH: int; | |||||
| strideW: int; | |||||
| strideH: int; | |||||
| padMode: PadMode; | |||||
| padUp: int; | |||||
| padDown: int; | |||||
| padLeft: int; | |||||
| padRight: int; | |||||
| roundMode: RoundMode; | |||||
| } | |||||
| table Shape { | table Shape { | ||||
| } | } | ||||
| @@ -286,7 +371,10 @@ table DeConv2D { | |||||
| hasBias: bool = false; | hasBias: bool = false; | ||||
| activationType: ActivationType = 0; | activationType: ActivationType = 0; | ||||
| } | } | ||||
| table BNGradInput { | |||||
| eps : float; | |||||
| channels: int; | |||||
| } | |||||
| table Scale { | table Scale { | ||||
| format: Format = 0; | format: Format = 0; | ||||
| } | } | ||||
| @@ -307,6 +395,17 @@ table Mul { | |||||
| table Div { | table Div { | ||||
| } | } | ||||
| table AddGrad { | |||||
| } | |||||
| table SubGrad { | |||||
| } | |||||
| table MulGrad { | |||||
| } | |||||
| table DivGrad { | |||||
| } | |||||
| table RealDiv { | table RealDiv { | ||||
| } | } | ||||
| @@ -389,7 +488,11 @@ table Power { | |||||
| scale: float; | scale: float; | ||||
| shift: float; | shift: float; | ||||
| } | } | ||||
| table PowerGrad { | |||||
| power: float; | |||||
| scale: float; | |||||
| shift: float; | |||||
| } | |||||
| table ArgMax { | table ArgMax { | ||||
| axis: int; | axis: int; | ||||
| outMaxValue: bool; | outMaxValue: bool; | ||||
| @@ -712,6 +815,10 @@ table SquaredDifference { | |||||
| table TupleGetItem { | table TupleGetItem { | ||||
| } | } | ||||
| table OptMomentum { | |||||
| } | |||||
| table Where{ | table Where{ | ||||
| } | } | ||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cmath> | |||||
| #include <cstddef> | |||||
| #include <iostream> | |||||
| #include "src/common/file_utils.h" | |||||
| #include "src/common/file_utils_ext.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| static int CompareOutputRelativeData(float *output_data, float *correct_data, int data_size) { | |||||
| float error = 0; | |||||
| // relative error | |||||
| float diffSum = 0.0f; | |||||
| float sum = 0.0f; | |||||
| for (int i = 0; i < data_size; i++) { | |||||
| sum += std::abs(correct_data[i]); | |||||
| } | |||||
| for (int i = 0; i < data_size; i++) { | |||||
| float diff = std::abs(output_data[i] - correct_data[i]); | |||||
| diffSum += diff; | |||||
| } | |||||
| error = diffSum / sum; | |||||
| if (error > 1e-4) { | |||||
| std::cout << "has accuracy error!\n" << error << "\n"; | |||||
| return 1; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| int CompareRelativeOutput(float *output_data, std::string file_path) { | |||||
| size_t output_size; | |||||
| auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); | |||||
| size_t output_num = output_size / sizeof(float); | |||||
| std::cout << "output num : " << output_num << "\n"; | |||||
| return CompareOutputRelativeData(output_data, ground_truth, output_num); | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,28 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_COMMON_FILE_UTILS_EXT_H_ | |||||
| #define MINDSPORE_LITE_COMMON_FILE_UTILS_EXT_H_ | |||||
| #include <string> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| int CompareRelativeOutput(float *output_data, std::string file_path); | |||||
| } | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_COMMON_FILE_UTILS_EXT_H_ | |||||
| @@ -64,7 +64,7 @@ class LiteKernel { | |||||
| LiteKernel() = default; | LiteKernel() = default; | ||||
| explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | ||||
| const std::vector<lite::tensor::Tensor *> &outputs) | const std::vector<lite::tensor::Tensor *> &outputs) | ||||
| : opParameter(parameter), inputs_(inputs), outputs_(outputs) { | |||||
| : opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false) { | |||||
| this->in_kernel_.clear(); | this->in_kernel_.clear(); | ||||
| this->out_kernel_.clear(); | this->out_kernel_.clear(); | ||||
| } | } | ||||
| @@ -77,7 +77,10 @@ class LiteKernel { | |||||
| virtual int Run() { return -1; } | virtual int Run() { return -1; } | ||||
| std::string Name() { return this->name; } | std::string Name() { return this->name; } | ||||
| virtual void train() { train_mode = true; } | |||||
| virtual bool is_train() { return train_mode == true; } | |||||
| virtual void eval() { train_mode = false; } | |||||
| virtual bool is_eval() { return train_mode == false; } | |||||
| void set_name(const std::string &name) { this->name = name; } | void set_name(const std::string &name) { this->name = name; } | ||||
| schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; } | schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; } | ||||
| @@ -117,6 +120,7 @@ class LiteKernel { | |||||
| std::vector<lite::tensor::Tensor *> outputs_; | std::vector<lite::tensor::Tensor *> outputs_; | ||||
| std::vector<LiteKernel *> in_kernel_; | std::vector<LiteKernel *> in_kernel_; | ||||
| std::vector<LiteKernel *> out_kernel_; | std::vector<LiteKernel *> out_kernel_; | ||||
| bool train_mode; | |||||
| }; | }; | ||||
| class SubGraphKernel : public LiteKernel { | class SubGraphKernel : public LiteKernel { | ||||
| @@ -0,0 +1,110 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/activation_grad.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::ActivationGradType_HSWISH; | |||||
| using mindspore::schema::ActivationGradType_LEAKY_RELU; | |||||
| using mindspore::schema::ActivationGradType_RELU; | |||||
| using mindspore::schema::ActivationGradType_RELU6; | |||||
| using mindspore::schema::PrimitiveType_ActivationGrad; | |||||
| namespace mindspore::kernel { | |||||
| int ActivationGradCPUKernel::Init() { | |||||
| outputs_[0]->set_shape(inputs_[0]->shape()); | |||||
| return RET_OK; | |||||
| } | |||||
| int ActivationGradCPUKernel::ReSize() { return RET_OK; } | |||||
| int ActivationGradCPUKernel::DoActivation(int task_id) { | |||||
| auto yt_addr = reinterpret_cast<float *>(inputs_.at(0)->Data()); | |||||
| auto input_addr = reinterpret_cast<float *>(inputs_.at(1)->Data()); | |||||
| auto output_addr = reinterpret_cast<float *>(outputs_.at(0)->Data()); | |||||
| auto length = inputs_.at(0)->ElementsNum(); | |||||
| auto error_code = RET_OK; | |||||
| if (type_ == schema::ActivationGradType_RELU) { | |||||
| error_code = ReluGrad(yt_addr, input_addr, length, output_addr); | |||||
| } else if (type_ == schema::ActivationGradType_RELU6) { | |||||
| error_code = Relu6Grad(yt_addr, input_addr, length, output_addr); | |||||
| } else if (type_ == schema::ActivationGradType_LEAKY_RELU) { | |||||
| error_code = LReluGrad(yt_addr, input_addr, length, output_addr, alpha_); | |||||
| } else if (type_ == schema::ActivationGradType_SIGMOID) { | |||||
| error_code = SigmoidGrad(yt_addr, input_addr, length, output_addr); | |||||
| } else if (type_ == schema::ActivationGradType_TANH) { | |||||
| error_code = TanhGrad(yt_addr, input_addr, length, output_addr); | |||||
| } else if (type_ == schema::ActivationGradType_HSWISH) { | |||||
| error_code = HSwishGrad(yt_addr, input_addr, length, output_addr); | |||||
| } else if (type_ == schema::ActivationGradType_HSIGMOID) { | |||||
| error_code = HSigmoidGrad(yt_addr, input_addr, length, output_addr); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Activation type error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (error_code != RET_OK) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ActivationGradRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||||
| auto activationGrad_kernel = reinterpret_cast<ActivationGradCPUKernel *>(cdata); | |||||
| auto error_code = activationGrad_kernel->DoActivation(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "ActivationGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ActivationGradCPUKernel::Run() { | |||||
| int error_code = LiteBackendParallelLaunch(ActivationGradRun, this, thread_count_); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_ActivationGrad); | |||||
| auto *kernel = new (std::nothrow) ActivationGradCPUKernel(opParameter, inputs, outputs); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ | |||||
| << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ActivationGrad, CpuActivationGradFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "ir/anf.h" | |||||
| #include "src/runtime/kernel/arm/opclib/activation_grad.h" | |||||
| namespace mindspore::kernel { | |||||
| class ActivationGradCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit ActivationGradCPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(param, inputs, outputs) { | |||||
| ActivationGradParameter *param_act_grad = reinterpret_cast<ActivationGradParameter *>(param); | |||||
| type_ = param_act_grad->type_; | |||||
| alpha_ = param_act_grad->alpha_; | |||||
| } | |||||
| ~ActivationGradCPUKernel() override = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int DoActivation(int task_id); | |||||
| private: | |||||
| int thread_count_; | |||||
| int type_; | |||||
| float alpha_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_GRAD_H_ | |||||
| @@ -0,0 +1,285 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/reduce_grad.h" | |||||
| #include "src/runtime/kernel/arm/fp32/arithmetic_grad.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/arithmetic_grad.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| namespace mindspore::kernel { | |||||
| namespace { | |||||
| constexpr int kArithGradOpInputNum = 3; | |||||
| constexpr int kArithGradOpOutputNum = 2; | |||||
| } // namespace | |||||
| int ArithmeticGradCPUKernel::Init() { | |||||
| auto ret = InferShape(); | |||||
| return ret; | |||||
| } | |||||
| int ArithmeticGradCPUKernel::InferShape() { | |||||
| if (inputs_.size() != kArithGradOpInputNum) { | |||||
| MS_LOG(ERROR) << "The number of input must be " << kArithGradOpInputNum; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (outputs_.size() != kArithGradOpOutputNum) { | |||||
| MS_LOG(ERROR) << "The number of output must be " << kArithGradOpOutputNum; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto dy = inputs_[0]; | |||||
| auto x1 = inputs_[1]; | |||||
| auto x2 = inputs_[2]; | |||||
| auto dx1 = outputs_[0]; | |||||
| auto dx2 = outputs_[1]; | |||||
| MS_ASSERT(dy != nullptr); | |||||
| MS_ASSERT(x1 != nullptr); | |||||
| MS_ASSERT(x2 != nullptr); | |||||
| MS_ASSERT(dx1 != nullptr); | |||||
| MS_ASSERT(dx2 != nullptr); | |||||
| auto inShape0 = x1->shape(); | |||||
| auto inShape1 = x2->shape(); | |||||
| auto outShape = dy->shape(); | |||||
| if ((type() == PrimitiveType_AddGrad) || (type() == PrimitiveType_SubGrad)) { | |||||
| arithmeticParameter_->ndim_ = outShape.size(); | |||||
| auto fillDimNum0 = outShape.size() - inShape0.size(); | |||||
| auto fillDimNum1 = outShape.size() - inShape1.size(); | |||||
| int j0 = 0; | |||||
| int j1 = 0; | |||||
| for (unsigned int i = 0; i < outShape.size(); i++) { | |||||
| arithmeticParameter_->in_shape0_[i] = (i < fillDimNum0) ? 1 : inShape0[j0++]; | |||||
| arithmeticParameter_->in_shape1_[i] = (i < fillDimNum1) ? 1 : inShape1[j1++]; | |||||
| arithmeticParameter_->out_shape_[i] = outShape[i]; | |||||
| } | |||||
| } else { | |||||
| // if (inShape0.size() < inShape1.size()) | |||||
| if (dx1->ElementsNum() < dx2->ElementsNum()) { | |||||
| arithmeticParameter_->ndim_ = inShape1.size(); | |||||
| if (type() == PrimitiveType_MulGrad) | |||||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul2L; | |||||
| else if (type() == PrimitiveType_DivGrad) | |||||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv2L; | |||||
| auto fillDimNum = inShape1.size() - inShape0.size(); // This will not work for batch! | |||||
| int j = 0; | |||||
| for (unsigned int i = 0; i < inShape1.size(); i++) { | |||||
| if (i < fillDimNum) { | |||||
| arithmeticParameter_->in_shape1_[i] = 1; | |||||
| } else { | |||||
| arithmeticParameter_->in_shape1_[i] = inShape0[j++]; | |||||
| } | |||||
| arithmeticParameter_->in_shape0_[i] = inShape1[i]; | |||||
| arithmeticParameter_->out_shape_[i] = outShape[i]; | |||||
| } | |||||
| } else if (dx2->ElementsNum() < dx1->ElementsNum()) { // if (inShape0.size() > inShape1.size()) | |||||
| arithmeticParameter_->ndim_ = inShape0.size(); | |||||
| if (type() == PrimitiveType_MulGrad) | |||||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul1L; | |||||
| else if (type() == PrimitiveType_DivGrad) | |||||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv1L; | |||||
| arithmeticParameter_->broadcasting_ = true; | |||||
| arithmeticParameter_->ndim_ = inShape0.size(); | |||||
| int j = 0; | |||||
| auto fillDimNum = inShape0.size() - inShape1.size(); | |||||
| for (unsigned int i = 0; i < inShape0.size(); i++) { | |||||
| if (i < fillDimNum) { | |||||
| arithmeticParameter_->in_shape1_[i] = 1; | |||||
| } else { | |||||
| arithmeticParameter_->in_shape1_[i] = inShape1[j++]; | |||||
| } | |||||
| arithmeticParameter_->in_shape0_[i] = inShape0[i]; | |||||
| arithmeticParameter_->out_shape_[i] = outShape[i]; | |||||
| } | |||||
| } else { | |||||
| arithmeticParameter_->broadcasting_ = false; | |||||
| for (unsigned int i = 0; i < inShape0.size(); i++) { | |||||
| arithmeticParameter_->in_shape1_[i] = inShape1[i]; | |||||
| arithmeticParameter_->in_shape0_[i] = inShape0[i]; | |||||
| arithmeticParameter_->out_shape_[i] = outShape[i]; | |||||
| } | |||||
| } | |||||
| tile_data0 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()]; | |||||
| MS_ASSERT(tile_data0 != nullptr); | |||||
| tile_data1 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()]; | |||||
| MS_ASSERT(tile_data1 != nullptr); | |||||
| if (type() == PrimitiveType_DivGrad) { | |||||
| tile_data2 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()]; | |||||
| MS_ASSERT(tile_data2 != nullptr); | |||||
| } | |||||
| } | |||||
| dx1->set_shape(x1->shape()); | |||||
| dx2->set_shape(x2->shape()); | |||||
| // outTensor->set_shape(out_shape); | |||||
| dx1->set_data_type(dy->data_type()); | |||||
| dx2->set_data_type(dy->data_type()); | |||||
| return RET_OK; | |||||
| } | |||||
| void ArithmeticGradCPUKernel::ArithmeticGradAdd(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, | |||||
| int dx2_size) { | |||||
| if (dx1_size == dy_size) | |||||
| memcpy(dx1, dy, dy_size * sizeof(float)); | |||||
| else | |||||
| ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx1, arithmeticParameter_->in_shape0_, | |||||
| arithmeticParameter_->ndim_); | |||||
| if (dx2_size == dy_size) | |||||
| memcpy(dx2, dy, dy_size * sizeof(float)); | |||||
| else | |||||
| ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx2, arithmeticParameter_->in_shape1_, | |||||
| arithmeticParameter_->ndim_); | |||||
| } | |||||
| void ArithmeticGradCPUKernel::ArithmeticGradSub(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, | |||||
| int dx2_size) { | |||||
| if (dx1_size == dy_size) | |||||
| memcpy(dx1, dy, dy_size * sizeof(float)); | |||||
| else | |||||
| ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx1, arithmeticParameter_->in_shape0_, | |||||
| arithmeticParameter_->ndim_); | |||||
| if (dx2_size == dy_size) { | |||||
| for (int i = 0; i < dx2_size; i++) { | |||||
| dx2[i] = -dy[i]; | |||||
| } | |||||
| } else { | |||||
| ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx2, arithmeticParameter_->in_shape1_, | |||||
| arithmeticParameter_->ndim_); | |||||
| for (int i = 0; i < dx2_size; i++) { | |||||
| dx2[i] = -dx2[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| void ArithmeticGradCPUKernel::ArithmeticGradMul(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, | |||||
| int dx2_size) { | |||||
| auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data()); | |||||
| auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data()); | |||||
| ElementMul(dy, x1_data, dx2, dy_size); | |||||
| ElementMul(dy, x2_data, dx1, dy_size); | |||||
| } | |||||
| void ArithmeticGradCPUKernel::ArithmeticGradMul1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, | |||||
| int dx2_size) { | |||||
| auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data()); | |||||
| auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data()); | |||||
| ElementMul(dy, x1_data, tile_data0, dy_size); | |||||
| ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx2, arithmeticParameter_->in_shape1_, | |||||
| arithmeticParameter_->ndim_); | |||||
| BroadcastMul(dy, x2_data, tile_data0, tile_data1, dx1, dy_size, arithmeticParameter_); // broadcast directly to dx1 | |||||
| } | |||||
| void ArithmeticGradCPUKernel::ArithmeticGradMul2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, | |||||
| int dx2_size) { | |||||
| auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data()); | |||||
| auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data()); | |||||
| ElementMul(dy, x2_data, tile_data0, dy_size); | |||||
| ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx1, arithmeticParameter_->in_shape1_, | |||||
| arithmeticParameter_->ndim_); | |||||
| BroadcastMul(dy, x1_data, tile_data0, tile_data1, dx2, dy_size, arithmeticParameter_); // broadcast directly to dx2 | |||||
| } | |||||
| void ArithmeticGradCPUKernel::ArithmeticGradDiv(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, | |||||
| int dx2_size) { | |||||
| auto x1 = reinterpret_cast<float *>(inputs_[1]->Data()); | |||||
| auto x2 = reinterpret_cast<float *>(inputs_[2]->Data()); | |||||
| ElementDiv(dy, x2, dx1, dy_size); | |||||
| ElementMulAndDivNegSquare(dy, x1, x2, dx2, dy_size); | |||||
| } | |||||
| void ArithmeticGradCPUKernel::ArithmeticGradDiv1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, | |||||
| int dx2_size) { | |||||
| auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data()); | |||||
| auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data()); | |||||
| ElementMul(x2_data, x2_data, dx2, dx2_size); | |||||
| ElementMul(x1_data, dy, dx1, dy_size); // use dx1 buffer | |||||
| BroadcastDiv(dx1, dx2, tile_data0, tile_data1, tile_data2, dy_size, | |||||
| arithmeticParameter_); // broadcast directly to dx1 | |||||
| ReduceSumByAxes(tile_data2, arithmeticParameter_->in_shape0_, dx2, arithmeticParameter_->in_shape1_, | |||||
| arithmeticParameter_->ndim_); | |||||
| for (int i = 0; i < dx2_size; i++) dx2[i] = -dx2[i]; | |||||
| // ReduceNegSumPrefix(tile_data2, dy_size, dx2, dx2_size); //then reduce into dx2 | |||||
| // broadcasting x2 | |||||
| BroadcastDiv(dy, x2_data, tile_data0, tile_data1, dx1, dy_size, arithmeticParameter_); // broadcast directly to dx1 | |||||
| } | |||||
| void ArithmeticGradCPUKernel::ArithmeticGradDiv2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, | |||||
| int dx2_size) { | |||||
| auto x1_data = reinterpret_cast<float *>(inputs_[1]->Data()); | |||||
| auto x2_data = reinterpret_cast<float *>(inputs_[2]->Data()); | |||||
| // dx1 = dy/x2 | |||||
| ElementDiv(dy, x2_data, tile_data0, dy_size); // first multiply into temp | |||||
| ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx1, arithmeticParameter_->in_shape1_, | |||||
| arithmeticParameter_->ndim_); | |||||
| // dx2 = -dy*x1/(x2*x2) | |||||
| BroadcastMul(dy, x1_data, tile_data0, tile_data1, tile_data2, dy_size, arithmeticParameter_); // broadcast numerator | |||||
| ElementDivNegSquare(tile_data2, x2_data, dx2, dy_size); | |||||
| } | |||||
| int ArithmeticGradCPUKernel::ReSize() { return RET_OK; } | |||||
| int ArithmeticGradCPUKernel::Run() { | |||||
| auto dy = reinterpret_cast<float *>(inputs_[0]->Data()); | |||||
| // auto input1_data1 = reinterpret_cast<float *>(inputs_[1]->Data()); | |||||
| auto dx1 = reinterpret_cast<float *>(outputs_[0]->Data()); | |||||
| auto dx2 = reinterpret_cast<float *>(outputs_[1]->Data()); | |||||
| size_t dy_size = inputs_.at(0)->ElementsNum(); | |||||
| size_t dx1_size = outputs_.at(0)->ElementsNum(); | |||||
| size_t dx2_size = outputs_[1]->ElementsNum(); | |||||
| (this->*arithmetic_grad_)(dy, dy_size, dx1, dx1_size, dx2, dx2_size); | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuArithmeticGradFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_EXCEPTION_IF_NULL(opParameter); | |||||
| if (opParameter == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto *kernel = new (std::nothrow) ArithmeticGradCPUKernel(opParameter, inputs, outputs); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulGrad, CpuArithmeticGradFp32KernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AddGrad, CpuArithmeticGradFp32KernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SubGrad, CpuArithmeticGradFp32KernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DivGrad, CpuArithmeticGradFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,90 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "ir/anf.h" | |||||
| using mindspore::schema::PrimitiveType_AddGrad; | |||||
| using mindspore::schema::PrimitiveType_DivGrad; | |||||
| using mindspore::schema::PrimitiveType_MulGrad; | |||||
| using mindspore::schema::PrimitiveType_SubGrad; | |||||
| namespace mindspore::kernel { | |||||
| class ArithmeticGradCPUKernel; | |||||
| class ArithmeticGradCPUKernel : public LiteKernel { | |||||
| typedef void (ArithmeticGradCPUKernel::*ArithmeticGradOperation)(float *, int, float *, int, float *, int); | |||||
| public: | |||||
| explicit ArithmeticGradCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs), tile_data0(NULL), tile_data1(NULL), tile_data2(NULL) { | |||||
| switch (type()) { | |||||
| case PrimitiveType_MulGrad: | |||||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul; // this will be adjusted in InferShape | |||||
| break; | |||||
| case PrimitiveType_AddGrad: | |||||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradAdd; | |||||
| break; | |||||
| case PrimitiveType_SubGrad: | |||||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradSub; | |||||
| break; | |||||
| case PrimitiveType_DivGrad: | |||||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv; // this will be adjusted in InferShape | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Error Operator type " << parameter->type_; | |||||
| break; | |||||
| } | |||||
| arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter); | |||||
| } | |||||
| ~ArithmeticGradCPUKernel() override { | |||||
| if (tile_data0) delete[] tile_data0; | |||||
| if (tile_data1) delete[] tile_data1; | |||||
| if (tile_data2) delete[] tile_data2; | |||||
| } | |||||
| void InitKernel(const CNodePtr &kernel_node); | |||||
| int Init() override; | |||||
| int InferShape(); | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| void ArithmeticGradAdd(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); | |||||
| void ArithmeticGradSub(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); | |||||
| void ArithmeticGradMul(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); | |||||
| void ArithmeticGradMul1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); | |||||
| void ArithmeticGradMul2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); | |||||
| void ArithmeticGradDiv(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); | |||||
| void ArithmeticGradDiv1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); | |||||
| void ArithmeticGradDiv2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); | |||||
| ArithmeticParameter *arithmeticParameter_; | |||||
| ArithmeticGradOperation arithmetic_grad_; | |||||
| float *tile_data0; | |||||
| float *tile_data1; | |||||
| float *tile_data2; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_GRAD_H_ | |||||
| @@ -0,0 +1,115 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include "src/runtime/kernel/arm/fp32/bias_grad.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_BiasGrad; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| namespace mindspore::kernel { | |||||
| int BiasGradCPUKernel::InferShape() { | |||||
| if (1 != this->inputs_.size()) { | |||||
| MS_LOG(ERROR) << "BiasGrad should have one input"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (1 != this->outputs_.size()) { | |||||
| MS_LOG(ERROR) << "BiasGrad should have one output"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto *in0 = inputs_.front(); | |||||
| auto *out = outputs_.front(); | |||||
| MS_ASSERT(in0 != nullptr); | |||||
| MS_ASSERT(out != nullptr); | |||||
| auto inshape = in0->shape(); | |||||
| int ndim = inshape.size(); | |||||
| for (int i = 0; i < ndim - 1; i++) { | |||||
| inshape[i] = 1; | |||||
| } | |||||
| out->set_shape(inshape); | |||||
| out->set_data_type(in0->data_type()); | |||||
| return RET_OK; | |||||
| } | |||||
| int BiasGradCPUKernel::Init() { | |||||
| MS_ASSERT(InferShape() == RET_OK); | |||||
| auto dims = inputs_[0]->shape(); | |||||
| bias_param->ndim_ = dims.size(); | |||||
| for (unsigned int i = 0; i < bias_param->ndim_; i++) { | |||||
| bias_param->in_shape0_[i] = dims[i]; | |||||
| bias_param->out_shape_[i] = 1; // 1 dimension for N,H,W, | |||||
| } | |||||
| bias_param->out_shape_[bias_param->ndim_ - 1] = dims[bias_param->ndim_ - 1]; | |||||
| for (int i = bias_param->ndim_; i < 4; i++) { | |||||
| bias_param->in_shape0_[i] = 0; | |||||
| bias_param->out_shape_[i] = 0; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int BiasGradCPUKernel::ReSize() { return 0; } | |||||
| int BiasGradCPUKernel::Run() { | |||||
| auto in = reinterpret_cast<float *>(inputs_.at(0)->Data()); | |||||
| auto out = reinterpret_cast<float *>(outputs_.at(0)->Data()); | |||||
| // size_t data_size = inputs_.at(0)->ElementsNum(); | |||||
| size_t nhw_size = 1; | |||||
| size_t channels = bias_param->in_shape0_[bias_param->ndim_ - 1]; // C in NHWC | |||||
| for (unsigned int i = 0; i < bias_param->ndim_ - 1; i++) nhw_size *= bias_param->in_shape0_[i]; | |||||
| size_t total_size = channels * nhw_size; | |||||
| for (size_t c = 0; c < channels; ++c) { | |||||
| out[c] = 0; | |||||
| for (size_t offset = 0; offset < total_size; offset += channels) { | |||||
| out[c] += in[offset + c]; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuBiasGradFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_BiasGrad); | |||||
| auto *kernel = new (std::nothrow) BiasGradCPUKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (RET_OK != ret) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BiasGrad, CpuBiasGradFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "ir/anf.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" | |||||
| namespace mindspore::kernel { | |||||
| class BiasGradCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit BiasGradCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs) { | |||||
| bias_param = reinterpret_cast<ArithmeticParameter *>(parameter); | |||||
| } | |||||
| ~BiasGradCPUKernel() override = default; | |||||
| int Init() override; | |||||
| int InferShape(); | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| ArithmeticParameter *bias_param; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_ | |||||
| @@ -0,0 +1,115 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <algorithm> | |||||
| #include <vector> | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_factory.h" | |||||
| #include "src/runtime/kernel/arm/fp32/bngrad_input.h" | |||||
| #include "src/runtime//kernel/arm/opclib/batch_norm.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| // using mindspore::lite::REG_OP; | |||||
| using mindspore::schema::PrimitiveType_BNGradInput; | |||||
| namespace mindspore::kernel { | |||||
| int BNGradInputCPUKernel::Init() { | |||||
| auto bn_param = reinterpret_cast<bnParameter *>(opParameter); | |||||
| workspace_size = 5 * bn_param->channels; | |||||
| workspace = new float[workspace_size]; | |||||
| if (2 != this->inputs_.size()) { | |||||
| MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (1 != this->outputs_.size()) { | |||||
| MS_LOG(ERROR) << "Conv2d Grad should has one output"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto *input_tensor = inputs_.at(0); | |||||
| // auto *weight_tensor = inputs_.at(1); | |||||
| auto *out_tensor = outputs_.at(0); | |||||
| auto in_shape = input_tensor->shape(); | |||||
| out_tensor->set_shape(in_shape); | |||||
| out_tensor->set_data_type(input_tensor->data_type()); | |||||
| return RET_OK; | |||||
| } | |||||
| int BNGradInputCPUKernel::ReSize() { return RET_OK; } | |||||
| /* | |||||
| according to https://wiseodd.github.io/techblog/2016/07/04/batchnorm | |||||
| */ | |||||
| int BNGradInputCPUKernel::Run() { | |||||
| // std::cout << "run succ" << std::endl; | |||||
| auto *input_x = inputs_.at(0); | |||||
| auto *input_yt = inputs_.at(1); | |||||
| auto *input_scale = inputs_.at(2); | |||||
| auto *output_grad = outputs_.at(0); | |||||
| // Tensor *bias = input[5]; | |||||
| auto bn_param = reinterpret_cast<bnParameter *>(opParameter); | |||||
| int batch = bn_param->batch; | |||||
| int channels = bn_param->channels; | |||||
| int spatial = bn_param->spatial; | |||||
| float eps = bn_param->eps; | |||||
| std::fill(workspace, workspace + workspace_size, 0.f); | |||||
| float *mean = workspace; | |||||
| float *variance = mean + channels; | |||||
| float *mean_delta = variance + channels; | |||||
| float *variance_delta = mean_delta + channels; | |||||
| float *mean_add_delta = variance_delta + channels; | |||||
| float *x = reinterpret_cast<float *>(input_x->Data()); | |||||
| float *yt = reinterpret_cast<float *>(input_yt->Data()); | |||||
| float *scale = reinterpret_cast<float *>(input_scale->Data()); | |||||
| float *out = reinterpret_cast<float *>(output_grad->Data()); | |||||
| std::copy(yt, yt + batch * channels * spatial, out); | |||||
| meanVar(x, batch, spatial, channels, mean, variance); | |||||
| scaleBias(scale, batch, channels, spatial, out); | |||||
| meanDelta(out, spatial, channels, eps, variance, mean_delta); | |||||
| varianceDelta(x, out, mean, variance, batch, channels, spatial, eps, variance_delta); | |||||
| meanAdd(x, mean, variance_delta, batch, channels, spatial, mean_add_delta, mean_delta); | |||||
| NormalizeDelta(x, mean, variance, mean_delta, variance_delta, batch, channels, eps, spatial, out); | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_BNGradInput); | |||||
| // parameter->name = opDef.name()->str().data(); | |||||
| // parameter->type = opDef.attr_type(); | |||||
| auto *kernel = new (std::nothrow) BNGradInputCPUKernel(opParameter, inputs, outputs); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (RET_OK != ret) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BNGradInput, CpuBNGradInputFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "ir/anf.h" | |||||
| namespace mindspore::kernel { | |||||
| class BNGradInputCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit BNGradInputCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs) {} | |||||
| ~BNGradInputCPUKernel() override { delete workspace; } | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| float *workspace; | |||||
| int workspace_size; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BNGRAD_INPUT_H_ | |||||
| @@ -0,0 +1,156 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/convolution_grad_filter.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/kernel/arm/opclib/pack.h" | |||||
| #include "src/runtime/kernel/arm/opclib/pack_ext.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/gemm.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Conv2DGradFilter; | |||||
| namespace mindspore::kernel { | |||||
| int ConvolutionGradFilterCPUKernel::Init() { | |||||
| // dy is in input 0 | |||||
| // x is in input 1 | |||||
| // dw is output 0 | |||||
| if (2 != this->inputs_.size()) { | |||||
| MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (1 != this->outputs_.size()) { | |||||
| MS_LOG(ERROR) << "Conv2d Grad should has one output"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto *input_tensor = inputs_.at(1); | |||||
| MS_ASSERT(input_tensor != nullptr); | |||||
| auto *dy = inputs_.at(0); | |||||
| MS_ASSERT(dy != nullptr); | |||||
| auto *weight_tensor = outputs_.at(0); | |||||
| MS_ASSERT(weight_tensor != nullptr); | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||||
| conv_param->output_batch_ = this->inputs_.at(0)->shape().at(kNHWC_N); | |||||
| conv_param->input_batch_ = this->inputs_.at(1)->shape().at(kNHWC_N); | |||||
| conv_param->input_h_ = this->inputs_.at(1)->shape().at(kNHWC_H); | |||||
| conv_param->input_w_ = this->inputs_.at(1)->shape().at(kNHWC_W); | |||||
| // assume OutCh|kh|kw|In | |||||
| conv_param->input_channel_ = this->inputs_.at(1)->shape().at(kNHWC_C); | |||||
| conv_param->output_channel_ = this->outputs_.at(0)->shape().at(kNHWC_N); | |||||
| int ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * | |||||
| conv_param->input_channel_ / conv_param->group_; | |||||
| workspace = new float[ws_size]; | |||||
| int output_w = 0; | |||||
| int output_h = 0; | |||||
| output_h = dy->shape()[kNHWC_H]; | |||||
| output_w = dy->shape()[kNHWC_W]; | |||||
| std::vector<int> out_shape(4); | |||||
| out_shape.at(0) = conv_param->output_channel_; | |||||
| out_shape.at(1) = conv_param->kernel_h_; | |||||
| out_shape.at(2) = conv_param->kernel_w_; | |||||
| out_shape.at(3) = conv_param->input_channel_ / conv_param->group_; | |||||
| // weight is output | |||||
| weight_tensor->set_shape(out_shape); | |||||
| weight_tensor->set_data_type(input_tensor->data_type()); | |||||
| conv_param->output_h_ = output_h; | |||||
| conv_param->output_w_ = output_w; | |||||
| return RET_OK; | |||||
| } | |||||
| int ConvolutionGradFilterCPUKernel::ReSize() { return 0; } | |||||
| int ConvolutionGradFilterCPUKernel::Run() { | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||||
| auto *input_dy = inputs_.at(0); | |||||
| auto *input_x = inputs_.at(1); | |||||
| auto *out_dw = outputs_.at(0); | |||||
| auto x_addr = reinterpret_cast<float *>(input_x->Data()); | |||||
| auto dy_addr = reinterpret_cast<float *>(input_dy->Data()); | |||||
| auto dw_addr = reinterpret_cast<float *>(out_dw->Data()); | |||||
| int i, j; | |||||
| int nweights = out_dw->ElementsNum(); | |||||
| int in_ch = conv_param->input_channel_; | |||||
| int in_h = conv_param->input_h_; | |||||
| int in_w = conv_param->input_w_; | |||||
| int k_h = conv_param->kernel_h_; // out_dw->shape()[1]; | |||||
| int k_w = conv_param->kernel_w_; // out_dw->shape()[2]; | |||||
| int batch = conv_param->output_batch_; | |||||
| int out_ch = conv_param->output_channel_; | |||||
| int groups = conv_param->group_; | |||||
| int out_h = conv_param->output_h_; | |||||
| int out_w = conv_param->output_w_; | |||||
| int m = out_h * out_w; | |||||
| int n = k_h * k_w * in_ch / groups; | |||||
| int k = out_ch / groups; | |||||
| // zero out pointer | |||||
| memset(dw_addr, 0, out_dw->Size()); | |||||
| for (i = 0; i < batch; ++i) { | |||||
| for (j = 0; j < groups; ++j) { | |||||
| float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups); | |||||
| float *mat_b = workspace; | |||||
| float *mat_c = dw_addr + j * nweights / groups; | |||||
| float *im = x_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups); | |||||
| im2row_hwc(im, mat_b, conv_param); | |||||
| gemm(1, 1, k, n, m, 1, mat_a, out_ch, mat_b, m, 1, mat_c, n); | |||||
| } | |||||
| } | |||||
| // std::cout << "run succ" << std::endl; | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuConvGradFilterFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradFilter); | |||||
| auto *kernel = new (std::nothrow) ConvolutionGradFilterCPUKernel(opParameter, inputs, outputs); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (RET_OK != ret) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2DGradFilter, CpuConvGradFilterFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "ir/anf.h" | |||||
| namespace mindspore::kernel { | |||||
| class ConvolutionGradFilterCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit ConvolutionGradFilterCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs) {} | |||||
| ~ConvolutionGradFilterCPUKernel() override { delete workspace; } | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| float *workspace; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_FILTER_H_ | |||||
| @@ -0,0 +1,136 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/convolution_grad_input.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/kernel/arm/opclib/pack.h" | |||||
| #include "src/runtime/kernel/arm/opclib/pack_ext.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/gemm.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_Conv2DGradInput; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| namespace mindspore::kernel { | |||||
| int ConvolutionGradInputCPUKernel::Init() { | |||||
| if (2 != this->inputs_.size()) { | |||||
| MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (1 != this->outputs_.size()) { | |||||
| MS_LOG(ERROR) << "Conv2d Grad should has one output"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto *dy_tensor = inputs_.at(kInputIndex); | |||||
| MS_ASSERT(dy_tensor != nullptr); | |||||
| auto *weight_tensor = inputs_.at(kWeightIndex); | |||||
| MS_ASSERT(weight_tensor != nullptr); | |||||
| auto *dx_tensor = outputs_.at(kOutputIndex); | |||||
| MS_ASSERT(dx_tensor != nullptr); | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||||
| conv_param->output_batch_ = dx_tensor->shape()[(kNHWC_N)]; | |||||
| conv_param->input_batch_ = dy_tensor->shape()[(kNHWC_N)]; | |||||
| conv_param->input_h_ = dx_tensor->shape()[(kNHWC_H)]; | |||||
| conv_param->input_w_ = dx_tensor->shape()[(kNHWC_W)]; | |||||
| // assume OutCh|kh|kw|In | |||||
| conv_param->input_channel_ = dx_tensor->shape()[(kNHWC_C)]; | |||||
| conv_param->output_channel_ = weight_tensor->shape()[(kNHWC_N)]; | |||||
| // TBD | |||||
| conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | |||||
| conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | |||||
| int ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * | |||||
| conv_param->input_channel_ / conv_param->group_; | |||||
| workspace = new float[ws_size]; | |||||
| return 0; | |||||
| } | |||||
| int ConvolutionGradInputCPUKernel::ReSize() { return 0; } | |||||
| int ConvolutionGradInputCPUKernel::Run() { | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||||
| auto *input_dy = inputs_.at(0); | |||||
| auto *input_w = inputs_.at(1); | |||||
| auto *out_dx = outputs_.at(0); | |||||
| auto dy_addr = reinterpret_cast<float *>(input_dy->Data()); | |||||
| auto w_addr = reinterpret_cast<float *>(input_w->Data()); | |||||
| auto dx_addr = reinterpret_cast<float *>(out_dx->Data()); | |||||
| int i, j; | |||||
| int nweights = input_w->ElementsNum(); | |||||
| int in_ch = conv_param->input_channel_; | |||||
| int in_h = conv_param->input_h_; | |||||
| int in_w = conv_param->input_w_; | |||||
| int k_h = conv_param->kernel_h_; // out_dw->shape()[1]; | |||||
| int k_w = conv_param->kernel_w_; // out_dw->shape()[2]; | |||||
| int batch = conv_param->output_batch_; | |||||
| int out_ch = conv_param->output_channel_; | |||||
| int groups = conv_param->group_; | |||||
| int out_h = conv_param->output_h_; | |||||
| int out_w = conv_param->output_w_; | |||||
| int m = out_h * out_w; | |||||
| int n = k_w * k_h * in_ch / groups; | |||||
| int k = out_ch / groups; | |||||
| memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w); | |||||
| for (i = 0; i < batch; ++i) { | |||||
| for (j = 0; j < groups; ++j) { | |||||
| float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups); | |||||
| float *mat_b = w_addr + j * nweights / groups; | |||||
| float *mat_c = workspace; | |||||
| gemm(0, 0, m, n, k, 1, mat_a, out_ch, mat_b, n, 0, mat_c, n); | |||||
| col2im_hwc(mat_c, dx_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups), conv_param); | |||||
| } | |||||
| } | |||||
| // std::cout << "run succ" << std::endl; | |||||
| return 0; | |||||
| } | |||||
| kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput); | |||||
| auto *kernel = new (std::nothrow) ConvolutionGradInputCPUKernel(opParameter, inputs, outputs); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (0 != ret) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2DGradInput, CpuConvGradInputFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "ir/anf.h" | |||||
| namespace mindspore::kernel { | |||||
| class ConvolutionGradInputCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit ConvolutionGradInputCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs) {} | |||||
| ~ConvolutionGradInputCPUKernel() override { delete workspace; } | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| float *workspace; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_GRAD_INPUT_H_ | |||||
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/kernel/arm/fp32/opt_momentum.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::schema::PrimitiveType_OptMomentum; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| namespace mindspore::kernel { | |||||
| int OptMomentumCPUKernel::ReSize() { return 0; } | |||||
| int OptMomentumCPUKernel::Run() { | |||||
| if (inputs_.size() != 5 || !outputs_.empty()) { | |||||
| MS_LOG(ERROR) << "OptMomentumCPUKernel error input output size!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (inputs_[0]->ElementsNum() != inputs_[1]->ElementsNum() || | |||||
| inputs_[0]->ElementsNum() != inputs_[3]->ElementsNum()) { | |||||
| MS_LOG(ERROR) << "error input data size!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto weight = reinterpret_cast<float *>(inputs_[0]->Data()); | |||||
| auto accumulate = reinterpret_cast<float *>(inputs_[1]->Data()); | |||||
| float learning_rate = reinterpret_cast<float *>(inputs_[2]->Data())[0]; | |||||
| auto gradient = reinterpret_cast<float *>(inputs_[3]->Data()); | |||||
| float moment = reinterpret_cast<float *>(inputs_[4]->Data())[0]; | |||||
| size_t elem_num = inputs_[0]->ElementsNum(); | |||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| accumulate[i] = accumulate[i] * moment + gradient[i]; | |||||
| weight[i] -= accumulate[i] * learning_rate; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int OptMomentumCPUKernel::Init() { return 0; } | |||||
| kernel::LiteKernel *CpuOptMomentumFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_OptMomentum); | |||||
| auto *kernel = new (std::nothrow) OptMomentumCPUKernel(opParameter, inputs, outputs); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (0 != ret) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OptMomentum, CpuOptMomentumFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_OPT_MOMENTUM_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPT_MOMENTUM_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "ir/anf.h" | |||||
| namespace mindspore::kernel { | |||||
| class OptMomentumCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit OptMomentumCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs) {} | |||||
| ~OptMomentumCPUKernel() override {} | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPT_MOMENTUM_H_ | |||||
| @@ -0,0 +1,195 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/pooling_grad.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/pooling.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/pooling_grad.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_PoolingGrad; | |||||
| namespace mindspore::kernel { | |||||
| #if 0 | |||||
| int PoolingGradCPUKernel::TfPadding(int input_w, int input_h, int &output_w, int &output_h) { | |||||
| PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *> (opParameter); | |||||
| auto stride_w = pool_param->stride_w_; | |||||
| auto stride_h = pool_param->stride_h_; | |||||
| auto window_w = pool_param->window_w_; | |||||
| auto window_h = pool_param->window_h_; | |||||
| auto pad_up = pool_param->pad_u_; | |||||
| auto pad_down = pool_param->pad_d_; | |||||
| auto pad_left = pool_param->pad_l_; | |||||
| auto pad_right = pool_param->pad_r_; | |||||
| if (pool_param->pad_mode_ == PADMODE_SAME) { | |||||
| output_w = ceil(input_w / stride_w); | |||||
| output_h = ceil(input_h / stride_h); | |||||
| } else { | |||||
| output_w = ceil((input_w + pad_left + pad_right - window_w + 1) / stride_w); | |||||
| output_h = ceil((input_h + pad_up + pad_down - window_h + 1) / stride_h); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int PoolingGradCPUKernel::CaffePadding(int input_w, int input_h, int &output_w, int &output_h) { | |||||
| PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *> (opParameter); | |||||
| auto round_mode = pool_param->round_mode_; | |||||
| auto stride_w = pool_param->stride_w_; | |||||
| auto stride_h = pool_param->stride_h_; | |||||
| auto window_w = pool_param->window_w_; | |||||
| auto window_h = pool_param->window_h_; | |||||
| auto pad_up = pool_param->pad_u_; | |||||
| auto pad_down = pool_param->pad_d_; | |||||
| auto pad_left = pool_param->pad_l_; | |||||
| auto pad_right = pool_param->pad_r_; | |||||
| if (round_mode == ROUNDMODE_FLOOR && false) { | |||||
| output_w = floor((input_w + pad_left + pad_right - window_w) / stride_w + 1); | |||||
| output_h = floor((input_h + pad_up + pad_down - window_h) / stride_h + 1); | |||||
| } else if (round_mode == ROUNDMODE_CEIL || true) { | |||||
| output_w = ceil((input_w + pad_left + pad_right - window_w) / stride_w + 1); | |||||
| output_h = ceil((input_h + pad_up + pad_down - window_h) / stride_h + 1); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "round mode not support."; | |||||
| } | |||||
| if (pad_left > 0 || pad_up > 0) { | |||||
| if ((output_w - 1) * stride_w >= input_w + pad_left) { | |||||
| --output_w; | |||||
| } | |||||
| if ((output_h - 1) * stride_h >= input_h + pad_up) { | |||||
| --output_h; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int PoolingGradCPUKernel::OnnxPadding(int input_w, int input_h, int &output_w, int &output_h) { | |||||
| PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *> (opParameter); | |||||
| auto round_mode = pool_param->round_mode_; | |||||
| auto stride_w = pool_param->stride_w_; | |||||
| auto stride_h = pool_param->stride_h_; | |||||
| auto window_w = pool_param->window_w_; | |||||
| auto window_h = pool_param->window_h_; | |||||
| auto pad_up = pool_param->pad_u_; | |||||
| auto pad_down = pool_param->pad_d_; | |||||
| auto pad_left = pool_param->pad_l_; | |||||
| auto pad_right = pool_param->pad_r_; | |||||
| if (round_mode == ROUNDMODE_FLOOR) { | |||||
| output_w = floor((input_w + pad_left + pad_right - window_w) / stride_w + 1); | |||||
| output_h = floor((input_h + pad_up + pad_down - window_h) / stride_h + 1); | |||||
| } else if (round_mode == ROUNDMODE_CEIL) { | |||||
| MS_LOG(ERROR) << "RoundMode_CEIL mode not support."; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "OnnxPadding round mode not support."; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #endif | |||||
| int PoolingGradCPUKernel::Init() { | |||||
| // InferShape(): | |||||
| // auto *in_tensor = reinterpret_cast<float *>(inputs_.at(0)->Data()); | |||||
| // auto *x_tensor = reinterpret_cast<float *>(inputs_.at(1)->Data()); | |||||
| PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(opParameter); | |||||
| auto in_shape = inputs_.at(0)->shape(); | |||||
| int input_h = in_shape.at(1); | |||||
| int input_w = in_shape.at(2); | |||||
| if (pool_param->global_) { | |||||
| pool_param->window_w_ = input_w; | |||||
| pool_param->window_h_ = input_h; | |||||
| } | |||||
| // Emir -- here I assume we get the outputshape in the output tensor | |||||
| auto *out_tensor = outputs_.front(); | |||||
| auto out_shape = out_tensor->shape(); | |||||
| #if 0 | |||||
| int output_w = 0, output_h = 0; | |||||
| auto fmk_type = pool_param->fmk_type_; | |||||
| switch (fmk_type) { | |||||
| case lite::FmkType_TF: | |||||
| break; | |||||
| case lite::FmkType_CAFFE: | |||||
| CaffePadding(input_w, input_h, output_w, output_h); | |||||
| break; | |||||
| case lite::FmkType_ONNX: | |||||
| OnnxPadding(input_w, input_h, output_w, output_h); | |||||
| break; | |||||
| case lite::FmkType_MS: | |||||
| break; | |||||
| case lite::FmkType_TFLITE: | |||||
| TfPadding(input_w, input_h, output_w, output_h); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Not support this framework."; | |||||
| } | |||||
| std::vector<int> out_shape{in_tensor->shape()}; | |||||
| out_shape.at(1) = output_h; | |||||
| out_shape.at(2) = output_w; | |||||
| #endif | |||||
| out_tensor->set_shape(out_shape); | |||||
| out_tensor->set_data_type(inputs_.at(0)->data_type()); | |||||
| return RET_OK; | |||||
| } | |||||
| int PoolingGradCPUKernel::ReSize() { return RET_OK; } | |||||
| int PoolingGradCPUKernel::Run() { | |||||
| PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(opParameter); | |||||
| auto input_ptr = reinterpret_cast<float *>(inputs_.at(0)->Data()); | |||||
| auto output_ptr = reinterpret_cast<float *>(outputs_.at(0)->Data()); | |||||
| if (pool_param->max_pooling_) { | |||||
| auto ind = reinterpret_cast<int *>(inputs_.at(1)->Data()); | |||||
| MaxPoolingGrad(input_ptr, ind, output_ptr, pool_param); | |||||
| } else { | |||||
| AvgPoolingGrad(input_ptr, output_ptr, pool_param); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuPoolingGradFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_PoolingGrad); | |||||
| auto *kernel = new (std::nothrow) PoolingGradCPUKernel(opParameter, inputs, outputs); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (RET_OK != ret) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_PoolingGrad, CpuPoolingGradFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "ir/anf.h" | |||||
| namespace mindspore::kernel { | |||||
| using mindspore::schema::PadMode; | |||||
| using mindspore::schema::PoolMode; | |||||
| using mindspore::schema::QuantType; | |||||
| using mindspore::schema::RoundMode; | |||||
| class PoolingGradCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit PoolingGradCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs) {} | |||||
| ~PoolingGradCPUKernel() override = default; | |||||
| // int TfPadding(int input_w, int input_h, int &output_w, int &output_h); | |||||
| // int CaffePadding(int input_w, int input_h, int &output_w, int &output_h); | |||||
| // int OnnxPadding(int input_w, int input_h, int &output_w, int &output_h); | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| uint8_t data_shape_{0}; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_GRAD_H_ | |||||
| @@ -0,0 +1,67 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/power_grad.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_PowerGrad; | |||||
| namespace mindspore::kernel { | |||||
| int PowerGradCPUKernel::Init() { return RET_OK; } | |||||
| int PowerGradCPUKernel::ReSize() { return RET_OK; } | |||||
| int PowerGradCPUKernel::Run() { | |||||
| auto dy_addr = reinterpret_cast<float *>(inputs_.at(0)->Data()); | |||||
| auto x_addr = reinterpret_cast<float *>(inputs_.at(1)->Data()); | |||||
| auto dx_addr = reinterpret_cast<float *>(outputs_.at(0)->Data()); | |||||
| auto size = inputs_.at(0)->ElementsNum(); | |||||
| Power(x_addr, dx_addr, size, power_ - 1, scale_, shift_); | |||||
| ElementMul(dx_addr, dy_addr, dx_addr, size); | |||||
| float scale = scale_ * power_; | |||||
| for (int i = 0; i < size; i++) { | |||||
| dx_addr[i] *= scale; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuPowerGradFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_PowerGrad); | |||||
| auto *kernel = new (std::nothrow) PowerGradCPUKernel(opParameter, inputs, outputs); | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_PowerGrad, CpuPowerGradFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "ir/anf.h" | |||||
| #include "src/runtime/kernel/arm/opclib/power.h" | |||||
| namespace mindspore::kernel { | |||||
| class PowerGradCPUKernel : public LiteKernel { | |||||
| public: | |||||
| PowerGradCPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(param, inputs, outputs) { | |||||
| PowerParameter *power_param = reinterpret_cast<PowerParameter *>(param); | |||||
| power_ = power_param->power_; | |||||
| scale_ = power_param->scale_; | |||||
| shift_ = power_param->shift_; | |||||
| } | |||||
| ~PowerGradCPUKernel() override = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| float power_; | |||||
| float scale_; | |||||
| float shift_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_GRAD_H_ | |||||
| @@ -0,0 +1,145 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/softmax.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_SoftmaxCrossEntropy; | |||||
| namespace mindspore::kernel { | |||||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; } | |||||
| void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses, | |||||
| float *output) const { | |||||
| float total_loss = 0; | |||||
| for (int i = 0; i < param->batch_size_; ++i) { | |||||
| if (labels[i] < 0) { | |||||
| MS_LOG(EXCEPTION) << "label value must >= 0"; | |||||
| } | |||||
| size_t label = labels[i]; | |||||
| if (label > param->number_of_classes_) { | |||||
| MS_LOG(EXCEPTION) << "error label input!"; | |||||
| } else { | |||||
| total_loss -= logf(losses[i * param->number_of_classes_ + label]); | |||||
| } | |||||
| } | |||||
| output[0] = total_loss / param->batch_size_; | |||||
| } | |||||
| void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, | |||||
| float *output) const { | |||||
| size_t row_start = 0; | |||||
| for (int i = 0; i < param->batch_size_; ++i) { | |||||
| if (labels[i] < 0) { | |||||
| MS_LOG(EXCEPTION) << "label value must >= 0"; | |||||
| } | |||||
| size_t label = labels[i]; | |||||
| if (label > param->number_of_classes_) { | |||||
| MS_LOG(EXCEPTION) << "error label input!"; | |||||
| } | |||||
| for (size_t j = 0; j < param->number_of_classes_; ++j) { | |||||
| size_t index = row_start + j; | |||||
| if (j == label) { | |||||
| output[index] = (losses[index] - 1) / param->batch_size_; | |||||
| } else { | |||||
| output[index] = losses[index] / param->batch_size_; | |||||
| } | |||||
| } | |||||
| row_start += param->number_of_classes_; | |||||
| } | |||||
| } | |||||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||||
| auto ins = reinterpret_cast<float *>(inputs_.at(0)->Data()); | |||||
| auto labels = reinterpret_cast<int *>(inputs_.at(1)->Data()); | |||||
| auto out = reinterpret_cast<float *>(outputs_.at(0)->Data()); | |||||
| float *grads = NULL; | |||||
| if (is_train()) { // outputs_.size() > 1) | |||||
| grads = reinterpret_cast<float *>(outputs_.at(0)->Data()); | |||||
| } | |||||
| size_t data_size = inputs_.at(0)->ElementsNum(); | |||||
| float *losses = new (std::nothrow) float[data_size]; | |||||
| MS_ASSERT(losses != nullptr); | |||||
| std::fill(losses, losses + data_size, 0); | |||||
| MS_ASSERT(out != nullptr); | |||||
| MS_ASSERT(labels != nullptr); | |||||
| MS_ASSERT(ins != nullptr); | |||||
| SoftmaxParameter sm_params; | |||||
| sm_params.n_dim_ = param->n_dim_; | |||||
| sm_params.element_size_ = data_size; | |||||
| sm_params.axis_ = 1; | |||||
| for (int i = 0; i < 4; i++) // softmax has only 4 params in shape | |||||
| sm_params.input_shape_[i] = param->input_shape_[i]; | |||||
| float sum_data[sm_params.input_shape_[sm_params.axis_]]; | |||||
| Softmax(ins, losses, sum_data, &sm_params); | |||||
| if (is_train()) { | |||||
| GradPostExecute(labels, losses, grads); | |||||
| } else { | |||||
| ForwardPostExecute(labels, losses, out); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() { | |||||
| auto dims = inputs_[0]->shape(); | |||||
| param->n_dim_ = 2; | |||||
| param->number_of_classes_ = dims[1]; | |||||
| param->batch_size_ = dims[0]; | |||||
| for (unsigned int i = 0; i < dims.size(); i++) param->input_shape_[i] = dims[i]; | |||||
| if (2 != this->inputs_.size()) { | |||||
| MS_LOG(ERROR) << "softmax entropy loss should have two inputs"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto *in0 = inputs_.front(); | |||||
| if (in0 == nullptr) { | |||||
| MS_LOG(ERROR) << "softmax etropy loss in0 have no data"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuSoftmaxCrossEntropyFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| MS_ASSERT(opParameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_SoftmaxCrossEntropy); | |||||
| auto *kernel = new (std::nothrow) SparseSoftmaxCrossEntropyWithLogitsCPUKernel(opParameter, inputs, outputs); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (RET_OK != ret) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropy, CpuSoftmaxCrossEntropyFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "ir/anf.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/softmax_grad.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" | |||||
| namespace mindspore::kernel { | |||||
| class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter, | |||||
| const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : LiteKernel(parameter, inputs, outputs) { | |||||
| param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter); | |||||
| } | |||||
| ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override = default; | |||||
| void ForwardPostExecute(const int *labels, const float *losses, float *output) const; | |||||
| void GradPostExecute(const int *labels, const float *losses, float *output) const; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| SoftmaxCrossEntropyParameter *param; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ | |||||
| @@ -0,0 +1,88 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ACTIVATION_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ACTIVATION_GRAD_H_ | |||||
| #include <math.h> | |||||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" | |||||
| #include "src/runtime/kernel/arm/opclib/errorcode.h" | |||||
| struct ActivationGradParameter { | |||||
| OpParameter op_parameter{}; | |||||
| int type_; | |||||
| float alpha_{0.01}; | |||||
| }; | |||||
| inline int ReluGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| dst[i] = src1[i] > 0 ? 1.0f : 0.0f; | |||||
| } | |||||
| ElementMul(src0, dst, dst, length); | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int Relu6Grad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| if (src1[i] < 0) { | |||||
| dst[i] = 0; | |||||
| } else { | |||||
| dst[i] = src1[i] > 6.0f ? 0.0f : 1.0f; | |||||
| } | |||||
| } | |||||
| ElementMul(src0, dst, dst, length); | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| dst[i] = src1[i] > 0.0f ? 1.0f : alpha; | |||||
| } | |||||
| ElementMul(src0, dst, dst, length); | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int SigmoidGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); | |||||
| } | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int TanhGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i]; | |||||
| } | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int HSwishGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f)); | |||||
| dst[i] = tmp * src0[i]; | |||||
| } | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| inline int HSigmoidGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f)); | |||||
| dst[i] = tmp * src0[i]; | |||||
| } | |||||
| return OPCLIB_OK; | |||||
| } | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ACTIVATION_GRAD_H_ | |||||
| @@ -0,0 +1,120 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <algorithm> | |||||
| #include <cmath> | |||||
| #include "src/runtime/kernel/arm/opclib/batch_norm.h" | |||||
| static void sumSpatialBatch(const float *in, int size, int ch, float *out) { | |||||
| std::fill(out, out + ch, 0.f); | |||||
| for (int i = 0; i < size; i++) { | |||||
| const float *ptr = in + i * ch; | |||||
| for (int c = 0; c < ch; c++) { | |||||
| out[c] += ptr[c]; | |||||
| } | |||||
| } | |||||
| } | |||||
| void scaleBias(const float *scales, int batch, int n, int size, float *output) { | |||||
| for (int i = 0; i < batch * size; i++) | |||||
| for (int c = 0; c < n; c++) output[i * n + c] *= scales[c]; | |||||
| } | |||||
| void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial, | |||||
| float *out) { | |||||
| int b, f, i; | |||||
| for (b = 0; b < batch; ++b) { | |||||
| for (i = 0; i < spatial; ++i) { | |||||
| for (f = 0; f < filters; ++f) { | |||||
| int index = b * filters * spatial + i * filters + f; | |||||
| out[index] = (x[index] - mean[f]) / (std::sqrt(variance[f]) + eps); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates) { | |||||
| int i, b, f; | |||||
| std::fill(scale_updates, scale_updates + n, 0.f); | |||||
| for (b = 0; b < batch; ++b) { | |||||
| for (i = 0; i < size; ++i) { | |||||
| for (f = 0; f < n; ++f) { | |||||
| int index = (b * size + i) * n + f; | |||||
| scale_updates[f] += delta[index] * x_norm[index]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void meanVar(const float *in, int batch, int spatial, int ch, float *mean, float *var) { | |||||
| float N = batch * spatial; | |||||
| sumSpatialBatch(in, N, ch, mean); | |||||
| for (int f = 0; f < ch; ++f) mean[f] /= N; | |||||
| std::fill(var, var + ch, 0.f); | |||||
| for (int i = 0; i < N; i++) { | |||||
| for (int f = 0; f < ch; f++) { | |||||
| float x = in[i * ch + f]; | |||||
| var[f] += (x - mean[f]) * (x - mean[f]); | |||||
| } | |||||
| } | |||||
| for (int f = 0; f < ch; f++) var[f] /= N; | |||||
| } | |||||
| void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta) { | |||||
| sumSpatialBatch(yt, size, ch, mean_delta); | |||||
| for (int i = 0; i < ch; i++) mean_delta[i] *= -1.f / std::sqrt((variance[i] + eps)); | |||||
| } | |||||
| void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial, | |||||
| float *mean_add, float *mean_delta) { | |||||
| int i, k; | |||||
| std::fill(mean_add, mean_add + filters, 0.f); | |||||
| for (k = 0; k < spatial * batch; ++k) { | |||||
| for (i = 0; i < filters; ++i) { | |||||
| int index = k * filters + i; | |||||
| mean_add[i] += x[index] - mean[i]; | |||||
| } | |||||
| } | |||||
| for (i = 0; i < filters; ++i) { | |||||
| mean_add[i] *= variance_delta[i] * (-2.f / (spatial * batch)); | |||||
| mean_delta[i] += mean_add[i]; | |||||
| } | |||||
| } | |||||
| void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int filters, | |||||
| int spatial, float eps, float *variance_delta) { | |||||
| int i, k; | |||||
| std::fill(variance_delta, variance_delta + filters, 0.f); | |||||
| for (k = 0; k < batch * spatial; k++) { | |||||
| for (i = 0; i < filters; i++) { | |||||
| int index = k * filters + i; | |||||
| variance_delta[i] += delta[index] * (x[index] - mean[i]); | |||||
| } | |||||
| } | |||||
| for (i = 0; i < filters; i++) variance_delta[i] *= -.5 * pow(variance[i] + eps, (-3.f / 2.f)); | |||||
| } | |||||
| void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, | |||||
| const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta) { | |||||
| int f, k; | |||||
| for (k = 0; k < batch * spatial; k++) { | |||||
| for (f = 0; f < filters; f++) { | |||||
| int index = k * filters + f; | |||||
| delta[index] = delta[index] * 1. / (std::sqrt(variance[f] + eps)) + | |||||
| variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + | |||||
| mean_delta[f] / (spatial * batch); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ | |||||
| #define MINDSPORE_LITE_SRC_BACKEND_ARM_BATCH_NORM_H_ | |||||
| struct bnParameter { | |||||
| int batch; | |||||
| int channels; | |||||
| int spatial; | |||||
| float eps; | |||||
| }; | |||||
| void scaleBias(const float *scales, int batch, int n, int size, float *output); | |||||
| void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial, | |||||
| float *out); | |||||
| void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates); | |||||
| void meanVar(const float *in, int batch, int size, int ch, float *mean, float *var); | |||||
| void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta); | |||||
| void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int ch, | |||||
| int spatial, float eps, float *variance_delta); | |||||
| void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial, | |||||
| float *mean_add, float *mean_delta); | |||||
| void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, | |||||
| const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta); | |||||
| #endif | |||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/arithmetic_grad.h" | |||||
| void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) { | |||||
| for (int i = 0; i < element_size; i++) { | |||||
| output[i] = -nom[i] / (denom[i] * denom[i]); | |||||
| } | |||||
| } | |||||
| void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size) { | |||||
| for (int i = 0; i < element_size; i++) { | |||||
| output[i] = -a[i] * b[i] / (denom[i] * denom[i]); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,22 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ARITHMETIC_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ARITHMETIC_GRAD_H_ | |||||
| void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size); | |||||
| void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size); | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ARITHMETIC_GRAD_H_ | |||||
| @@ -0,0 +1,108 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/gemm.h" | |||||
| static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c, | |||||
| int ldc) { | |||||
| int i, j, k; | |||||
| for (i = 0; i < M; ++i) { | |||||
| for (k = 0; k < K; ++k) { | |||||
| float a = alpha * mat_a[i * lda + k]; | |||||
| for (j = 0; j < N; ++j) { | |||||
| mat_c[i * ldc + j] += a * mat_B[k * ldb + j]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, | |||||
| int ldc) { | |||||
| int i, j, k; | |||||
| for (i = 0; i < M; ++i) { | |||||
| for (j = 0; j < N; ++j) { | |||||
| float sum = 0; | |||||
| for (k = 0; k < K; ++k) { | |||||
| sum += alpha * mat_a[i * lda + k] * mat_b[j * ldb + k]; | |||||
| } | |||||
| mat_c[i * ldc + j] += sum; | |||||
| } | |||||
| } | |||||
| } | |||||
| static void gemm_tn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, | |||||
| int ldc) { | |||||
| int i, j, k; | |||||
| for (i = 0; i < M; ++i) { | |||||
| for (k = 0; k < K; ++k) { | |||||
| float a = alpha * mat_a[k * lda + i]; | |||||
| for (j = 0; j < N; ++j) { | |||||
| mat_c[i * ldc + j] += a * mat_b[k * ldb + j]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, | |||||
| int ldc) { | |||||
| int i, j, k; | |||||
| for (i = 0; i < M; ++i) { | |||||
| for (j = 0; j < N; ++j) { | |||||
| float sum = 0; | |||||
| for (k = 0; k < K; ++k) { | |||||
| sum += alpha * mat_a[i + k * lda] * mat_b[k + j * ldb]; | |||||
| } | |||||
| mat_c[i * ldc + j] += sum; | |||||
| } | |||||
| } | |||||
| } | |||||
| // mat_c = alpha*op( mat_a )*op( mat_b ) + beta*C | |||||
| // M - number of rows of matrix a | |||||
| // N - number of cols of matrix b | |||||
| // K - number of cols of matrix a | |||||
| void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, | |||||
| int ldb, float beta, float *mat_c, int ldc) { | |||||
| // printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc); | |||||
| if (beta >= 0.f && beta <= 0.f) { | |||||
| for (int i = 0; i < M; ++i) { | |||||
| for (int j = 0; j < N; ++j) { | |||||
| mat_c[i * ldc + j] = 0; | |||||
| } | |||||
| } | |||||
| } else if (beta < 1.f || beta > 1.f) { | |||||
| for (int i = 0; i < M; ++i) { | |||||
| for (int j = 0; j < N; ++j) { | |||||
| mat_c[i * ldc + j] *= beta; | |||||
| } | |||||
| } | |||||
| } | |||||
| int t; | |||||
| for (t = 0; t < M; ++t) { | |||||
| if (!transpose_a && !transpose_b) { | |||||
| gemm_nn(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||||
| } else if (transpose_a && !transpose_b) { | |||||
| gemm_tn(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||||
| } else if (!transpose_a && transpose_b) { | |||||
| gemm_nt(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||||
| } else { | |||||
| gemm_tt(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,23 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_GEMM_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_GEMM_H_ | |||||
| void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, | |||||
| int ldb, float beta, float *mat_c, int ldc); | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_GEMM_H_ | |||||
| @@ -0,0 +1,149 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cstdint> | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/pooling_grad.h" | |||||
| void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) { | |||||
| int stride_w = pooling_param->stride_w_; | |||||
| int stride_h = pooling_param->stride_h_; | |||||
| int pad_w = pooling_param->pad_l_; | |||||
| int pad_h = pooling_param->pad_u_; | |||||
| int win_w = pooling_param->window_w_; | |||||
| int win_h = pooling_param->window_h_; | |||||
| int channel = pooling_param->input_channel_; | |||||
| int in_w = pooling_param->input_w_; | |||||
| int in_h = pooling_param->input_h_; | |||||
| int output_w = pooling_param->output_w_; | |||||
| int output_h = pooling_param->output_h_; | |||||
| int output_batch = pooling_param->output_batch_; | |||||
| const float *inPtr; | |||||
| for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0; | |||||
| // int pad_top = padding[2]; | |||||
| float kk = static_cast<float>(win_h * win_w); | |||||
| for (uint16_t ib = 0; ib < output_batch; ib++) { | |||||
| // int in_batch_offset = batch * in_h * in_w * channel; | |||||
| // int out_batch_offset = batch * output_h * output_w * channel; | |||||
| // out = grads->getData(ib*grads->imgSize()); | |||||
| // inPtr = in->getData(ib*in->imgSize()); | |||||
| float *out; | |||||
| out = &output_ptr[(ib * output_h * output_w)]; | |||||
| inPtr = reinterpret_cast<const float *>(&input_ptr[(ib * in_h * in_w)]); | |||||
| if (1) { // in->layout() == Tensor::nhwc) | |||||
| // iterate over yt | |||||
| for (uint16_t yh = 0; yh < in_h; yh++) { | |||||
| for (uint16_t yw = 0; yw < in_w; yw++) { | |||||
| for (uint16_t ic = 0; ic < channel; ic++) { | |||||
| int idx = (yw + yh * in_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw; | |||||
| float delta = inPtr[idx] / kk; | |||||
| for (int32_t kh = 0; kh < win_h; kh++) { | |||||
| int xh = yh * stride_h + kh - pad_h; | |||||
| if ((xh < 0) || (xh >= output_h)) { | |||||
| continue; | |||||
| } | |||||
| for (int32_t kw = 0; kw < win_w; kw++) { | |||||
| int xw = yw * stride_w + kw - pad_w; | |||||
| if ((xw < 0) || (xw >= output_w)) { | |||||
| continue; | |||||
| } | |||||
| // out[(ic*output_h*output_w) + (xh*output_w) + xw] += delta; | |||||
| out[(xw + output_w * xh) * channel + ic] += delta; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { // nchw | |||||
| for (uint16_t ic = 0; ic < channel; ic++) { | |||||
| // iterate over yt | |||||
| for (uint16_t yh = 0; yh < in_h; yh++) { | |||||
| for (uint16_t yw = 0; yw < in_w; yw++) { | |||||
| int idx = (ic * in_h * in_w) + (in_w * yh) + yw; | |||||
| float delta = inPtr[idx] / kk; | |||||
| for (int32_t kh = 0; kh < win_h; kh++) { | |||||
| int xh = yh * stride_h + kh - pad_h; | |||||
| if ((xh < 0) || (xh >= output_h)) { | |||||
| continue; | |||||
| } | |||||
| for (int32_t kw = 0; kw < win_w; kw++) { | |||||
| int xw = yw * stride_w + kw - pad_w; | |||||
| if ((xw < 0) || (xw >= output_w)) { | |||||
| continue; | |||||
| } | |||||
| out[(ic * output_h * output_w) + (xh * output_w) + xw] += delta; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void MaxPoolingGrad(const float *dy, const int *indices, float *output_ptr, PoolingParameter *pooling_param) { | |||||
| // int stride_w = pooling_param->stride_w_; | |||||
| // int stride_h = pooling_param->stride_h_; | |||||
| // int pad_w = pooling_param->pad_l_; | |||||
| // int pad_h = pooling_param->pad_u_; | |||||
| // int win_w = pooling_param->window_w_; | |||||
| // int win_h = pooling_param->window_h_; | |||||
| int channel = pooling_param->input_channel_; | |||||
| int in_w = pooling_param->input_w_; | |||||
| int in_h = pooling_param->input_h_; | |||||
| int output_w = pooling_param->output_w_; | |||||
| int output_h = pooling_param->output_h_; | |||||
| int output_batch = pooling_param->output_batch_; | |||||
| int out_img_size = | |||||
| output_h * output_w; // Emir -- in original code this varible is calculated according to input size ?? | |||||
| int ind_img_size = in_h * in_w; | |||||
| // const int w_pad = (output_w + pad_w + pad_w); | |||||
| for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0; | |||||
| const float *yt = reinterpret_cast<const float *>(dy); | |||||
| const int *pos = reinterpret_cast<const int *>(indices); | |||||
| float *out; | |||||
| if (1) { // grads->layout() == Tensor::nhwc) | |||||
| for (int ib = 0; ib < output_batch; ib++) { | |||||
| out = &(output_ptr[ib * output_w * output_w * channel]); | |||||
| for (int ix = 0; ix < ind_img_size; ix++) { | |||||
| for (int cix = 0; cix < channel; cix++) { | |||||
| int idx = (*pos) * channel + cix; | |||||
| out[idx] += *yt; | |||||
| pos++; | |||||
| yt++; | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (int ib = 0; ib < output_batch; ib++) { | |||||
| out = &output_ptr[(ib * out_img_size)]; | |||||
| for (int cix = 0; cix < channel; cix++) { | |||||
| for (int ix = 0; ix < ind_img_size; ix++) { | |||||
| int idx = cix * output_h * output_w + *pos; // cord_y*output_w + cord_x; | |||||
| out[idx] += *yt; | |||||
| pos++; | |||||
| yt++; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,25 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_POOLING_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_POOLING_GRAD_H_ | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/pooling.h" | |||||
| void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param); | |||||
| void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param); | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_POOLING_GRAD_H_ | |||||
| @@ -0,0 +1,130 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cstddef> | |||||
| #include <algorithm> | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce_grad.h" | |||||
| static inline bool NextIndex(const int num_dims, const int *dims, int *current) { | |||||
| int carry = 1; | |||||
| for (int idx = num_dims - 1; idx >= 0; --idx) { | |||||
| int current_val = current[idx] + carry; | |||||
| if (dims[idx] == current_val) { | |||||
| current[idx] = 0; | |||||
| } else { | |||||
| current[idx] = current_val; | |||||
| carry = 0; | |||||
| break; | |||||
| } | |||||
| } | |||||
| return (carry == 0); | |||||
| } | |||||
| static inline size_t GetInputOffset(const int num_dims, const int *dims, const int *iter) { | |||||
| size_t offset = 0; | |||||
| for (int idx = 0; idx < num_dims; ++idx) { | |||||
| offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]); | |||||
| } | |||||
| return offset; | |||||
| } | |||||
| static inline size_t GetOutputOffset(const int num_dims, const int *dims, const int *iter, const int num_axis, | |||||
| const int *axes) { | |||||
| size_t offset = 0; | |||||
| for (int idx = 0; idx < num_dims; ++idx) { | |||||
| // if we need to skip this axis | |||||
| bool is_axis = false; | |||||
| for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { | |||||
| if (idx == axes[axis_idx]) { | |||||
| is_axis = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!is_axis) { | |||||
| offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]); | |||||
| } | |||||
| } | |||||
| return offset; | |||||
| } | |||||
| void ReduceMeanByAxes(const float *input_data, int *input_iter, const int *input_dims, int input_num_dims, | |||||
| const int *axes, int num_axes, float *output_data, const int *output_dims, int output_num_dims) { | |||||
| size_t num_outputs = 1; | |||||
| for (int idx = 0; idx < output_num_dims; ++idx) { | |||||
| size_t current = (size_t)(output_dims[idx]); | |||||
| num_outputs *= current; | |||||
| } | |||||
| // Reset input iterator. | |||||
| for (int idx = 0; idx < input_num_dims; ++idx) { | |||||
| input_iter[idx] = 0; | |||||
| } | |||||
| // Iterate through input_data. | |||||
| do { | |||||
| size_t input_offset = GetInputOffset(input_num_dims, input_dims, input_iter); | |||||
| size_t output_offset = GetOutputOffset(input_num_dims, input_dims, input_iter, num_axes, axes); | |||||
| output_data[output_offset] += input_data[input_offset]; | |||||
| } while (NextIndex(input_num_dims, input_dims, input_iter)); | |||||
| // Calculate mean by dividing output_data by num of aggregated element. | |||||
| size_t num_elements_in_axis = 1; | |||||
| for (int idx = 0; idx < num_axes; ++idx) { | |||||
| size_t current = (size_t)(input_dims[axes[idx]]); | |||||
| num_elements_in_axis *= current; | |||||
| } | |||||
| for (size_t idx = 0; idx < num_outputs; ++idx) { | |||||
| output_data[idx] = output_data[idx] / static_cast<float>(num_elements_in_axis); | |||||
| } | |||||
| } | |||||
| float ReduceMeanAll(const float *src, int size) { | |||||
| float sum = 0; | |||||
| for (int i = 0; i < size; ++i) { | |||||
| sum += src[i]; | |||||
| } | |||||
| return sum / size; | |||||
| } | |||||
| void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims) { | |||||
| int num_outputs = 1; | |||||
| int same_shape = true; | |||||
| for (int idx = 0; idx < num_dims; ++idx) { | |||||
| num_outputs *= output_dims[idx]; | |||||
| if (output_dims[idx] != input_dims[idx]) same_shape = false; | |||||
| } | |||||
| if (same_shape) { | |||||
| std::copy(input, input + num_outputs * sizeof(float), output); | |||||
| // memcpy(output, input, num_outputs*sizeof(float)); | |||||
| return; | |||||
| } | |||||
| for (int idx = 0; idx < num_outputs; ++idx) output[idx] = 0; // zero output | |||||
| int input_iter[8] = {0}; | |||||
| int axes[5] = {0}; | |||||
| int num_axes = 0; | |||||
| for (int i = 0; i < num_dims; i++) | |||||
| if (output_dims[i] == 1) axes[num_axes++] = i; | |||||
| // Iterate through input_data. | |||||
| do { | |||||
| size_t input_offset = GetInputOffset(num_dims, input_dims, input_iter); | |||||
| size_t output_offset = GetOutputOffset(num_dims, input_dims, input_iter, num_axes, axes); | |||||
| output[output_offset] += input[input_offset]; | |||||
| } while (NextIndex(num_dims, input_dims, input_iter)); | |||||
| } | |||||
| @@ -0,0 +1,24 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_REDUCE_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_REDUCE_GRAD_H_ | |||||
| float ReduceMeanAll(const float *src, int size); | |||||
| void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims); | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_REDUCE_GRAD_H_ | |||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_GRAD_H_ | |||||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | |||||
| struct SoftmaxCrossEntropyParameter { | |||||
| OpParameter op_parameter; | |||||
| int32_t batch_size_; | |||||
| unsigned int number_of_classes_; | |||||
| int n_dim_; | |||||
| int input_shape_[5]; | |||||
| }; | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_GRAD_H_ | |||||
| @@ -0,0 +1,176 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string.h> | |||||
| #include "src/runtime/kernel/arm/opclib/pack_ext.h" | |||||
| static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); } | |||||
| void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param) { | |||||
| const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_; | |||||
| // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; | |||||
| const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_; | |||||
| // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; | |||||
| const int stride_h = conv_param->stride_h_; | |||||
| const int stride_w = conv_param->stride_w_; | |||||
| const int dilation_h = conv_param->dilation_h_; | |||||
| const int dilation_w = conv_param->dilation_w_; | |||||
| const int kernel_h = conv_param->kernel_h_; | |||||
| const int kernel_w = conv_param->kernel_w_; | |||||
| const int in_height = conv_param->input_h_; | |||||
| const int in_width = conv_param->input_w_; | |||||
| const int output_h = conv_param->output_h_; | |||||
| const int output_w = conv_param->output_w_; | |||||
| const int channels = conv_param->input_channel_ / conv_param->group_; | |||||
| const int tot_channels = conv_param->input_channel_; | |||||
| int /*channel,*/ kernel_row, kernel_col, output_rows, output_col; | |||||
| int row_stride_offset = 0; | |||||
| for (output_rows = output_h; output_rows; output_rows--) { | |||||
| int col_stride_offset = 0; | |||||
| for (output_col = output_w; output_col; output_col--) { | |||||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||||
| int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; | |||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||||
| int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; | |||||
| if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||||
| const int offset = (input_row * in_width + input_col) * tot_channels; | |||||
| memcpy(data_col, in_data + offset, sizeof(float) * channels); | |||||
| data_col += channels; | |||||
| } else { | |||||
| memset(data_col, 0, sizeof(float) * channels); | |||||
| data_col += channels; | |||||
| } | |||||
| } | |||||
| } | |||||
| col_stride_offset += stride_w; | |||||
| } | |||||
| row_stride_offset += stride_h; | |||||
| } | |||||
| } | |||||
| // output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w) | |||||
| void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param) { | |||||
| const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_; | |||||
| // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; | |||||
| const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_; | |||||
| // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; | |||||
| const int stride_h = conv_param->stride_h_; | |||||
| const int stride_w = conv_param->stride_w_; | |||||
| const int dilation_h = conv_param->dilation_h_; | |||||
| const int dilation_w = conv_param->dilation_w_; | |||||
| const int kernel_h = conv_param->kernel_h_; | |||||
| const int kernel_w = conv_param->kernel_w_; | |||||
| const int in_height = conv_param->input_h_; | |||||
| const int in_width = conv_param->input_w_; | |||||
| const int output_h = conv_param->output_h_; | |||||
| const int output_w = conv_param->output_w_; | |||||
| const int channels = conv_param->input_channel_ / conv_param->group_; | |||||
| const int tot_channels = conv_param->input_channel_; | |||||
| int channel, kernel_row, kernel_col, output_rows, output_col; | |||||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||||
| for (channel = 0; channel < channels; channel++) { | |||||
| int input_row = -pad_up + kernel_row * dilation_h; | |||||
| for (output_rows = output_h; output_rows; output_rows--) { | |||||
| if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { | |||||
| for (output_col = output_w; output_col; output_col--) { | |||||
| *(data_row++) = 0; | |||||
| } | |||||
| } else { | |||||
| int input_col = -pad_left + kernel_col * dilation_w; | |||||
| for (output_col = output_w; output_col; output_col--) { | |||||
| if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||||
| const int offset = (input_row * in_width + input_col) * tot_channels + channel; | |||||
| *(data_row++) = in_data[offset]; | |||||
| } else { | |||||
| *(data_row++) = 0; | |||||
| } | |||||
| input_col += stride_w; | |||||
| } | |||||
| } | |||||
| input_row += stride_h; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param) { | |||||
| const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_; | |||||
| // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; | |||||
| const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_; | |||||
| // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; | |||||
| const int stride_h = conv_param->stride_h_; | |||||
| const int stride_w = conv_param->stride_w_; | |||||
| const int dilation_h = conv_param->dilation_h_; | |||||
| const int dilation_w = conv_param->dilation_w_; | |||||
| const int kernel_h = conv_param->kernel_h_; | |||||
| const int kernel_w = conv_param->kernel_w_; | |||||
| const int in_height = conv_param->input_h_; | |||||
| const int in_width = conv_param->input_w_; | |||||
| const int output_h = conv_param->output_h_; | |||||
| const int output_w = conv_param->output_w_; | |||||
| const int channels = conv_param->input_channel_ / conv_param->group_; | |||||
| const int tot_channels = conv_param->input_channel_; | |||||
| int kernel_row, kernel_col, output_rows, output_col; | |||||
| int row_stride_offset = 0; | |||||
| for (output_rows = output_h; output_rows; output_rows--) { | |||||
| int col_stride_offset = 0; | |||||
| for (output_col = output_w; output_col; output_col--) { | |||||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||||
| int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; | |||||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||||
| int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; | |||||
| if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||||
| int offset = (input_row * in_width + input_col) * tot_channels; | |||||
| float *data_im_ptr = &data_im[offset]; | |||||
| for (int i = 0; i < channels; i++) { | |||||
| data_im_ptr[i] += data_col[i]; | |||||
| } | |||||
| } | |||||
| data_col += channels; | |||||
| } | |||||
| } | |||||
| col_stride_offset += stride_w; | |||||
| } | |||||
| row_stride_offset += stride_h; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_EXT_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_EXT_H_ | |||||
| #include "src/runtime/kernel/arm/opclib/conv_parameter.h" | |||||
| void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param); | |||||
| void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param); | |||||
| void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param); | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_EXT_H | |||||
| @@ -152,6 +152,7 @@ set(TEST_LITE_SRC | |||||
| ${LITE_DIR}/src/scheduler.cc | ${LITE_DIR}/src/scheduler.cc | ||||
| ${LITE_DIR}/src/common/graph_util.cc | ${LITE_DIR}/src/common/graph_util.cc | ||||
| ${LITE_DIR}/src/common/file_utils.cc | ${LITE_DIR}/src/common/file_utils.cc | ||||
| ${LITE_DIR}/src/common/file_utils_ext.cc | |||||
| ${LITE_DIR}/src/common/utils.cc | ${LITE_DIR}/src/common/utils.cc | ||||
| ${LITE_DIR}/tools/common/graph_util.cc | ${LITE_DIR}/tools/common/graph_util.cc | ||||
| ${LITE_DIR}/tools/common/tensor_util.cc | ${LITE_DIR}/tools/common/tensor_util.cc | ||||
| @@ -0,0 +1,312 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "src/common/file_utils.h" | |||||
| #include "src/common/file_utils_ext.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | |||||
| #include "mindspore/lite/src/ir/tensor.h" | |||||
| #include "mindspore/lite/src/lite_kernel.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h" | |||||
| namespace mindspore { | |||||
| class TestActGradFp32 : public mindspore::Common { | |||||
| public: | |||||
| TestActGradFp32() {} | |||||
| }; | |||||
| TEST_F(TestActGradFp32, ReluGradFp32) { | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = 50; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/activationGrad/relu_y_50.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::string yt_path = "./test_data/activationGrad/relu_yt_50.bin"; | |||||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||||
| auto output_data = new float[output_data_size]; | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| ReluGrad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| ReluGrad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/activationGrad/relu_out_50.bin"; | |||||
| int res = lite::CompareRelativeOutput(output_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| delete input_data; | |||||
| delete[] output_data; | |||||
| delete yt_data; | |||||
| MS_LOG(INFO) << "ReluGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestActGradFp32, Relu6GradFp32) { | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = 50; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/activationGrad/relu6_y_50.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::string yt_path = "./test_data/activationGrad/relu6_yt_50.bin"; | |||||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||||
| auto output_data = new float[output_data_size]; | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| Relu6Grad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| Relu6Grad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/activationGrad/relu6_out_50.bin"; | |||||
| int res = lite::CompareRelativeOutput(output_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| delete input_data; | |||||
| delete[] output_data; | |||||
| delete yt_data; | |||||
| MS_LOG(INFO) << "Relu6GradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestActGradFp32, LReluGradFp32) { | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = 50; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/activationGrad/lrelu_y_50.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::string yt_path = "./test_data/activationGrad/lrelu_yt_50.bin"; | |||||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||||
| auto output_data = new float[output_data_size]; | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| LReluGrad(yt_data, input_data, 50, output_data, 0.1); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| LReluGrad(yt_data, input_data, 50, output_data, 0.1); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/activationGrad/lrelu_out_50.bin"; | |||||
| int res = lite::CompareRelativeOutput(output_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| delete input_data; | |||||
| delete[] output_data; | |||||
| delete yt_data; | |||||
| MS_LOG(INFO) << "LReluGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestActGradFp32, SigmoidGradFp32) { | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = 50; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/activationGrad/sigmoid_y_50.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::string yt_path = "./test_data/activationGrad/sigmoid_yt_50.bin"; | |||||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||||
| auto output_data = new float[output_data_size]; | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| SigmoidGrad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| SigmoidGrad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/activationGrad/sigmoid_out_50.bin"; | |||||
| int res = lite::CompareRelativeOutput(output_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| // lite::CompareOutput(output_data, output_path); | |||||
| delete input_data; | |||||
| delete[] output_data; | |||||
| delete yt_data; | |||||
| MS_LOG(INFO) << "SigmoidGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestActGradFp32, tanhGradFp32) { | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = 50; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/activationGrad/tanh_y_50.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::string yt_path = "./test_data/activationGrad/tanh_yt_50.bin"; | |||||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||||
| auto output_data = new float[output_data_size]; | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| TanhGrad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| TanhGrad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/activationGrad/tanh_out_50.bin"; | |||||
| int res = lite::CompareRelativeOutput(output_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| delete input_data; | |||||
| delete[] output_data; | |||||
| delete yt_data; | |||||
| MS_LOG(INFO) << "TanhGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestActGradFp32, hswishGradFp32) { | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = 50; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/activationGrad/hswish_x_50.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::string yt_path = "./test_data/activationGrad/hswish_yt_50.bin"; | |||||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||||
| auto output_data = new float[output_data_size]; | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| HSwishGrad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| HSwishGrad(yt_data, input_data, 50, output_data); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/activationGrad/hswish_out_50.bin"; | |||||
| int res = lite::CompareRelativeOutput(output_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| delete input_data; | |||||
| delete[] output_data; | |||||
| delete yt_data; | |||||
| MS_LOG(INFO) << "hswishGradFp32 passed"; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,497 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "src/common/file_utils.h" | |||||
| #include "src/common/file_utils_ext.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | |||||
| namespace mindspore { | |||||
| class TestArithmeticGradFp32 : public mindspore::Common { | |||||
| public: | |||||
| TestArithmeticGradFp32() {} | |||||
| }; | |||||
| std::vector<lite::tensor::Tensor *> GenerateTensorsForTest(const char *test, int test_id) { | |||||
| size_t input_size; | |||||
| std::vector<int> large_dim({4, 6}); | |||||
| std::vector<int> small_dim({6}); | |||||
| int large_size = (4 * 6); | |||||
| int small_size = (1 * 6); | |||||
| char *dx1_file = const_cast<char *>("./test_data/operators/arithmetic_fp32_1_x1_4_6.bin"); | |||||
| char *dx2_file = const_cast<char *>("./test_data/operators/arithmetic_fp32_1_x2_1_6.bin"); | |||||
| if (test_id == 7) { | |||||
| large_dim = std::vector<int>({4, 5, 6}); | |||||
| small_dim = std::vector<int>({6}); | |||||
| large_size = (4 * 5 * 6); | |||||
| small_size = (6); | |||||
| dx1_file = const_cast<char *>("./test_data/operators/arithmetic_fp32_7_x1_4_5_6.bin"); | |||||
| dx2_file = const_cast<char *>("./test_data/operators/arithmetic_fp32_7_x2_1_1_6.bin"); | |||||
| } | |||||
| if (test_id >= 8) { | |||||
| large_dim = std::vector<int>({5, 4, 6}); | |||||
| small_dim = std::vector<int>({5, 1, 6}); | |||||
| large_size = (4 * 5 * 6); | |||||
| small_size = (5 * 6); | |||||
| dx1_file = const_cast<char *>("./test_data/operators/arithmetic_fp32_8_x1_5_4_6.bin"); | |||||
| dx2_file = const_cast<char *>("./test_data/operators/arithmetic_fp32_8_x2_5_1_6.bin"); | |||||
| } | |||||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(test, &input_size)); | |||||
| lite::tensor::Tensor *dy_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, large_dim); | |||||
| dy_tensor->SetData(dy_data); | |||||
| auto x1_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dx1_file, &input_size)); | |||||
| lite::tensor::Tensor *x1_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, large_dim); | |||||
| x1_tensor->SetData(x1_data); | |||||
| auto x2_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dx2_file, &input_size)); | |||||
| lite::tensor::Tensor *x2_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, small_dim); | |||||
| x2_tensor->SetData(x2_data); | |||||
| auto dx1_data = new float[large_size]; | |||||
| lite::tensor::Tensor *dx1_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, large_dim); | |||||
| dx1_tensor->SetData(dx1_data); | |||||
| auto dx2_data = new float[small_size]; | |||||
| lite::tensor::Tensor *dx2_tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, small_dim); | |||||
| dx2_tensor->SetData(dx2_data); | |||||
| std::vector<lite::tensor::Tensor *> ret_vector = {dy_tensor, x1_tensor, x2_tensor, dx1_tensor, dx2_tensor}; | |||||
| return ret_vector; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestAddGradFp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_AddGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_1_dy_4_6.bin", 1); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_1_dx1_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[0]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_1_dx2_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestAddGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestAddGrad2Fp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_AddGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_1_dy_4_6.bin", 1); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_1_dx1_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[1]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_1_dx2_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestAddGrad2Fp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestAddGrad3Fp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_AddGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_8_dy_5_4_6.bin", 8); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_8_dx2_5_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[1]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_8_dx1_5_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestAddGrad3Fp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestSubGradFp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_SubGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_2_dy_4_6.bin", 2); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SubGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_2_dx1_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[0]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_2_dx2_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestSubGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestSubGrad2Fp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_SubGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_3_dy_4_6.bin", 3); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SubGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_3_dx1_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[1]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_3_dx2_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestSubGrad2Fp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestMulGradFp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_MulGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_4_dy_4_6.bin", 4); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| int loop_count = 1000; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| kernel_obj->Run(); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| printf("total cost (for %d loops): %lu us\n", loop_count, cost); | |||||
| // auto time_avg = cost / loop_count; | |||||
| // printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_4_dx1_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[0]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_4_dx2_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestMulGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestMulGrad2Fp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_MulGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_4_dy_4_6.bin", 4); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_4_dx1_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[1]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_4_dx2_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestMulGrad2Fp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestMulGrad3Fp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_MulGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_9_dy_5_4_6.bin", 9); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_9_dx1_5_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[0]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestMulGrad3Fp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestMulGrad4Fp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_MulGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_9_dy_5_4_6.bin", 9); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_9_dx1_5_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[1]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestMulGrad4Fp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestDivGradFp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_DivGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_5_dy_4_6.bin", 5); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_5_dx1_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[0]->Data()), output_path)); | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_5_dx2_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestDivGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestDivGrad2Fp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_DivGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_6_dy_4_6.bin", 6); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[2], all_tensors[1]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string dx2_path = "./test_data/operators/arithmetic_fp32_6_dx2_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[1]->Data()), dx2_path)); | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_6_dx1_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestDivGrad2Fp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, TestDivGrad3Fp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_DivGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_10_dy_5_4_6.bin", 10); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string dx1_path = "./test_data/operators/arithmetic_fp32_10_dx1_5_4_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[0]->Data()), dx1_path)); | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_10_dx2_5_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestDivGrad3Fp32 passed"; | |||||
| } | |||||
| TEST_F(TestArithmeticGradFp32, Test3DDivGrad2Fp32) { | |||||
| auto param = new ArithmeticParameter(); | |||||
| param->op_parameter_.type_ = PrimitiveType_DivGrad; | |||||
| std::vector<lite::tensor::Tensor *> all_tensors = | |||||
| GenerateTensorsForTest("./test_data/operators/arithmetic_fp32_7_dy_4_5_6.bin", 7); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {all_tensors[0], all_tensors[1], all_tensors[2]}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data()); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 6; i++) { | |||||
| std::cout << output_ptr[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string dx1_path = "./test_data/operators/arithmetic_fp32_7_dx1_4_5_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(reinterpret_cast<float *>(outputs[0]->Data()), dx1_path)); | |||||
| std::string output_path = "./test_data/operators/arithmetic_fp32_7_dx2_1_1_6.bin"; | |||||
| EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path)); | |||||
| for (int i = 0; i < 5; i++) delete all_tensors[i]; | |||||
| delete param; | |||||
| MS_LOG(INFO) << "TestDivGrad2Fp32 passed"; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "src/common/file_utils.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | |||||
| namespace mindspore { | |||||
| class TestBiasGradFp32 : public mindspore::Common { | |||||
| public: | |||||
| TestBiasGradFp32() {} | |||||
| }; | |||||
| TEST_F(TestBiasGradFp32, BiasGradFp32) { | |||||
| // prepare stage | |||||
| auto bias_param = new ArithmeticParameter(); | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/operators/biasgradfp32_1_dy_10_28_28_7.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::vector<int> dim_dy({10, 28, 28, 7}); | |||||
| lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||||
| dy_tensor.SetData(input_data); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor}; | |||||
| auto output_data = new float[7]; | |||||
| std::vector<int> dim_dw({7}); | |||||
| lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||||
| dw_tensor.SetData(output_data); | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&dw_tensor}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_BiasGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(bias_param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 7; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/operators/biasgradfp32_1_db_7.bin"; | |||||
| lite::CompareOutput(output_data, output_path); | |||||
| // delete input_data; | |||||
| // delete[] output_data; | |||||
| delete bias_param; | |||||
| MS_LOG(INFO) << "BiasGradFp32 passed"; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,521 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "src/common/file_utils.h" | |||||
| #include "src/common/file_utils_ext.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | |||||
| namespace mindspore { | |||||
| class TestConvolutionGradFp32 : public mindspore::Common { | |||||
| public: | |||||
| TestConvolutionGradFp32() {} | |||||
| }; | |||||
| void InitConvParamGroup1FP32(ConvParameter *conv_param) { | |||||
| conv_param->input_batch_ = 1; | |||||
| conv_param->input_h_ = 28; | |||||
| conv_param->input_w_ = 28; | |||||
| conv_param->input_channel_ = 3; | |||||
| conv_param->output_batch_ = 1; | |||||
| conv_param->output_h_ = 28; | |||||
| conv_param->output_w_ = 28; | |||||
| conv_param->output_channel_ = 32; | |||||
| conv_param->kernel_h_ = 3; | |||||
| conv_param->kernel_w_ = 3; | |||||
| conv_param->stride_h_ = 1; | |||||
| conv_param->stride_w_ = 1; | |||||
| conv_param->dilation_h_ = 1; | |||||
| conv_param->dilation_w_ = 1; | |||||
| conv_param->pad_h_ = 1; | |||||
| conv_param->pad_w_ = 1; | |||||
| conv_param->group_ = 1; | |||||
| conv_param->is_relu_ = false; | |||||
| conv_param->is_relu6_ = false; | |||||
| conv_param->thread_num_ = 1; | |||||
| } | |||||
| void InitConvParamGroup3FP32(ConvParameter *conv_param) { | |||||
| InitConvParamGroup1FP32(conv_param); | |||||
| conv_param->group_ = 3; | |||||
| conv_param->output_channel_ = 18; | |||||
| } | |||||
| void InitConvParamGroup3Dilation2FP32(ConvParameter *conv_param) { | |||||
| InitConvParamGroup3FP32(conv_param); | |||||
| conv_param->dilation_h_ = 2; | |||||
| conv_param->dilation_w_ = 2; | |||||
| conv_param->output_h_ = 26; | |||||
| conv_param->output_w_ = 26; | |||||
| } | |||||
| TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) { | |||||
| // prepare stage | |||||
| auto conv_param = new ConvParameter(); | |||||
| InitConvParamGroup1FP32(conv_param); | |||||
| size_t dy_size; | |||||
| std::string dy_path = "./test_data/conv/convfp32_dy_1_28_28_32.bin"; | |||||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||||
| std::vector<int> dim_dy({1, 28, 28, 32}); | |||||
| lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||||
| dy_tensor.SetData(dy_data); | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = | |||||
| conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/conv/convfp32_x_1_28_28_3.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::vector<int> dim_x({1, 28, 28, 3}); | |||||
| lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||||
| x_tensor.SetData(input_data); | |||||
| auto dw_data = new float[output_data_size]; | |||||
| std::vector<int> dim_dw({32, 3, 3, 3}); | |||||
| lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||||
| dw_tensor.SetData(dw_data); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&dw_tensor}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc); | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| std::string output_path = "./test_data/conv/convfp32_dw_32_3_3_3.bin"; | |||||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| // delete input_data; | |||||
| // delete dy_data; | |||||
| // delete [] dw_data; | |||||
| delete kernel; | |||||
| delete conv_param; | |||||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; | |||||
| } | |||||
| TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) { | |||||
| // prepare stage | |||||
| auto conv_param = new ConvParameter(); | |||||
| InitConvParamGroup1FP32(conv_param); | |||||
| size_t dy_size; | |||||
| std::string dy_path = "./test_data/conv/convfp32_dy_1_28_28_32.bin"; | |||||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||||
| std::vector<int> dim_dy({1, 28, 28, 32}); | |||||
| lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||||
| dy_tensor.SetData(dy_data); | |||||
| size_t w_size; | |||||
| std::string w_path = "./test_data/conv/convfp32_w_32_3_3_3.bin"; | |||||
| auto w_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); | |||||
| std::vector<int> dim_dw({32, 3, 3, 3}); | |||||
| lite::tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||||
| w_tensor.SetData(w_data); | |||||
| size_t output_data_size = | |||||
| conv_param->input_batch_ * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; | |||||
| auto dx_data = new float[output_data_size]; | |||||
| std::vector<int> dim_dx({1, 28, 28, 3}); | |||||
| lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); | |||||
| dx_tensor.SetData(dx_data); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor, &w_tensor}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&dx_tensor}; | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc); | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| std::string output_path = "./test_data/conv/convfp32_dx_1_28_28_3.bin"; | |||||
| auto res = lite::CompareRelativeOutput(dx_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| delete kernel; | |||||
| delete conv_param; | |||||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; | |||||
| } | |||||
| TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) { | |||||
| // prepare stage | |||||
| auto conv_param = new ConvParameter(); | |||||
| InitConvParamGroup3FP32(conv_param); | |||||
| size_t dy_size; | |||||
| std::string dy_path = "./test_data/conv/convfp32_dy_g3_1_28_28_18.bin"; | |||||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||||
| std::vector<int> dim_dy({1, 28, 28, 18}); | |||||
| lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||||
| dy_tensor.SetData(dy_data); | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * | |||||
| conv_param->input_channel_ / conv_param->group_; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/conv/convfp32_x_g3_1_28_28_3.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::vector<int> dim_x({1, 28, 28, 3}); | |||||
| lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||||
| x_tensor.SetData(input_data); | |||||
| auto dw_data = new float[output_data_size]; | |||||
| std::vector<int> dim_dw({18, 3, 3, 1}); | |||||
| lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||||
| dw_tensor.SetData(dw_data); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&dw_tensor}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc); | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| std::string output_path = "./test_data/conv/convfp32_dw_g3_18_3_3_3.bin"; | |||||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| // delete input_data; | |||||
| // delete dy_data; | |||||
| // delete [] dw_data; | |||||
| delete kernel; | |||||
| delete conv_param; | |||||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; | |||||
| } | |||||
| TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) { | |||||
| // prepare stage | |||||
| auto conv_param = new ConvParameter(); | |||||
| InitConvParamGroup3FP32(conv_param); | |||||
| size_t dy_size; | |||||
| std::string dy_path = "./test_data/conv/convfp32_dy_g3_1_28_28_18.bin"; | |||||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||||
| std::vector<int> dim_dy({1, 28, 28, 18}); | |||||
| lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||||
| dy_tensor.SetData(dy_data); | |||||
| size_t w_size; | |||||
| std::string w_path = "./test_data/conv/convfp32_w_g3_18_3_3_3.bin"; | |||||
| auto w_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); | |||||
| std::vector<int> dim_dw({18, 3, 3, 1}); | |||||
| lite::tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||||
| w_tensor.SetData(w_data); | |||||
| size_t output_data_size = | |||||
| conv_param->input_batch_ * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; | |||||
| auto dx_data = new float[output_data_size]; | |||||
| std::vector<int> dim_dx({1, 28, 28, 3}); | |||||
| lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); | |||||
| dx_tensor.SetData(dx_data); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor, &w_tensor}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&dx_tensor}; | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc); | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| std::string output_path = "./test_data/conv/convfp32_dx_g3_1_28_28_3.bin"; | |||||
| auto res = lite::CompareRelativeOutput(dx_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| delete kernel; | |||||
| delete conv_param; | |||||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; | |||||
| } | |||||
| TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) { | |||||
| // prepare stage | |||||
| auto conv_param = new ConvParameter(); | |||||
| InitConvParamGroup3Dilation2FP32(conv_param); | |||||
| size_t dy_size; | |||||
| std::string dy_path = "./test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin"; | |||||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||||
| std::vector<int> dim_dy({1, 26, 26, 18}); | |||||
| lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||||
| dy_tensor.SetData(dy_data); | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * | |||||
| conv_param->input_channel_ / conv_param->group_; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::vector<int> dim_x({1, 28, 28, 3}); | |||||
| lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||||
| x_tensor.SetData(input_data); | |||||
| auto dw_data = new float[output_data_size]; | |||||
| std::vector<int> dim_dw({18, 3, 3, 1}); | |||||
| lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||||
| dw_tensor.SetData(dw_data); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&dw_tensor}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc); | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| std::string output_path = "./test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin"; | |||||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| // delete input_data; | |||||
| // delete dy_data; | |||||
| // delete [] dw_data; | |||||
| delete kernel; | |||||
| delete conv_param; | |||||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; | |||||
| } | |||||
| TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationInputGrad) { | |||||
| // prepare stage | |||||
| auto conv_param = new ConvParameter(); | |||||
| InitConvParamGroup3Dilation2FP32(conv_param); | |||||
| size_t dy_size; | |||||
| std::string dy_path = "./test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin"; | |||||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||||
| std::vector<int> dim_dy({1, 26, 26, 18}); | |||||
| lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||||
| dy_tensor.SetData(dy_data); | |||||
| size_t w_size; | |||||
| std::string w_path = "./test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin"; | |||||
| auto w_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); | |||||
| std::vector<int> dim_w({18, 3, 3, 1}); | |||||
| lite::tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_w); | |||||
| w_tensor.SetData(w_data); | |||||
| size_t output_data_size = | |||||
| conv_param->input_batch_ * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; | |||||
| auto dx_data = new float[output_data_size]; | |||||
| std::vector<int> dim_dx({1, 28, 28, 3}); | |||||
| lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); | |||||
| dx_tensor.SetData(dx_data); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor, &w_tensor}; | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&dx_tensor}; | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc); | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| kernel->Run(); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| std::string output_path = "./test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin"; | |||||
| auto res = lite::CompareRelativeOutput(dx_data, output_path); | |||||
| EXPECT_EQ(res, 0); | |||||
| delete kernel; | |||||
| delete conv_param; | |||||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; | |||||
| } | |||||
| // TEST_F(TestConvolutionGradFp32, ConvGroupDilation) { | |||||
| // // prepare stage | |||||
| // auto conv_param = new ConvParameter(); | |||||
| // InitConvParamGroup3Dilation2FP32(conv_param); | |||||
| // size_t x_size; | |||||
| // std::string x_path = "./test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin"; | |||||
| // auto x_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(x_path.c_str(), &x_size)); | |||||
| // std::vector<int> dim_x({1, 28, 28, 3}); | |||||
| // tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||||
| // x_tensor.SetData(x_data); | |||||
| // size_t w_size; | |||||
| // std::string w_path = "./test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin"; | |||||
| // auto w_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); | |||||
| // std::vector<int> dim_w({18, 3, 3, 1}); | |||||
| // tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_w); | |||||
| // w_tensor.SetData(w_data); | |||||
| // size_t output_data_size = | |||||
| // conv_param->output_batch_ * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; | |||||
| // auto y_data = new float[output_data_size]; | |||||
| // std::vector<int> dim_y({1, 26, 26, 18}); | |||||
| // tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); | |||||
| // y_tensor.SetData(y_data); | |||||
| // std::vector<tensor::Tensor *> inputs = {&x_tensor, &w_tensor}; | |||||
| // std::vector<tensor::Tensor *> outputs = {&y_tensor}; | |||||
| // // runtime part | |||||
| // printf("Calculating runtime cost...\n"); | |||||
| // uint64_t time_avg = 0; | |||||
| // lite::Context context; | |||||
| // ; | |||||
| // context.deviceCtx.type = lite::DT_CPU; | |||||
| // context.threadNum = 1; | |||||
| // kernel::KernelKey desc = {kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Conv2D}; | |||||
| // auto creator = lite::KernelRegistry::GetInstance()->GetKernelCreator(desc); | |||||
| // auto kernel = creator(inputs, outputs, (OpParameter *)conv_param, &context, desc); | |||||
| // kernel->train(); | |||||
| // EXPECT_EQ(kernel->is_train(), 1); | |||||
| // // warm up loop | |||||
| // for (int i = 0; i < 3; i++) { | |||||
| // kernel->Run(); | |||||
| // } | |||||
| // int loop_count = 100; | |||||
| // auto time_start = mindspore::lite::GetTimeUs(); | |||||
| // for (int i = 0; i < loop_count; i++) { | |||||
| // kernel->Run(); | |||||
| // } | |||||
| // auto time_end = mindspore::lite::GetTimeUs(); | |||||
| // auto cost = time_end - time_start; | |||||
| // time_avg = cost / loop_count; | |||||
| // printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| // std::string output_path = "./test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin"; | |||||
| // auto res = lite::CompareRelativeOutput(y_data, output_path); | |||||
| // EXPECT_EQ(res, 0); | |||||
| // delete kernel; | |||||
| // delete conv_param; | |||||
| // MS_LOG(INFO) << "TestConvolutionFp32 Filter Grad passed"; | |||||
| // } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,332 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "mindspore/lite/include/context.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/src/kernel_registry.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/common/file_utils.h" | |||||
| #include "src/runtime/kernel/arm/fp32/pooling_grad.h" | |||||
| #include "src/runtime/kernel/arm/opclib/fp32/pooling_grad.h" | |||||
| namespace mindspore { | |||||
| class TestPoolingGradFp32 : public mindspore::Common { | |||||
| public: | |||||
| TestPoolingGradFp32() {} | |||||
| }; | |||||
| void InitPoolingParamFP32(PoolingParameter *pooling_param) { | |||||
| pooling_param->input_batch_ = 1; | |||||
| pooling_param->input_h_ = 28; | |||||
| pooling_param->input_w_ = 28; | |||||
| pooling_param->input_channel_ = 3; | |||||
| pooling_param->output_batch_ = 1; | |||||
| pooling_param->output_h_ = 28; | |||||
| pooling_param->output_w_ = 28; | |||||
| pooling_param->output_channel_ = 32; | |||||
| pooling_param->window_h_ = 3; | |||||
| pooling_param->window_w_ = 3; | |||||
| pooling_param->stride_h_ = 1; | |||||
| pooling_param->stride_w_ = 1; | |||||
| pooling_param->pad_u_ = 1; | |||||
| pooling_param->pad_d_ = 1; | |||||
| pooling_param->pad_l_ = 1; | |||||
| pooling_param->pad_r_ = 1; | |||||
| pooling_param->thread_num_ = 1; | |||||
| } | |||||
| TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) { | |||||
| // prepare stage | |||||
| auto pooling_param = new PoolingParameter(); | |||||
| InitPoolingParamFP32(pooling_param); | |||||
| pooling_param->output_channel_ = 3; | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = | |||||
| pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/pooling/avgpoolgradfp32_1_dy_1_28_28_3.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| auto output_data = new float[output_data_size]; | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| AvgPoolingGrad(input_data, output_data, pooling_param); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| AvgPoolingGrad(input_data, output_data, pooling_param); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; | |||||
| lite::CompareOutput(output_data, output_path); | |||||
| delete input_data; | |||||
| delete[] output_data; | |||||
| delete pooling_param; | |||||
| MS_LOG(INFO) << "TestAvgPoolingGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { | |||||
| // prepare stage | |||||
| auto pooling_param = new PoolingParameter(); | |||||
| InitPoolingParamFP32(pooling_param); | |||||
| pooling_param->output_channel_ = 3; | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| // uint64_t time_avg = 0; | |||||
| size_t output_data_size = | |||||
| pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/pooling/avgpoolgradfp32_1_dy_1_28_28_3.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| std::vector<int> dim_dy({1, 28, 28, 3}); | |||||
| lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||||
| dy_tensor.SetData(input_data); | |||||
| std::string input1_path = "./test_data/pooling/avgpoolgradfp32_1_x_1_28_28_3.bin"; | |||||
| input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1_path.c_str(), &input_size)); | |||||
| std::vector<int> dim_x({1, 28, 28, 3}); | |||||
| lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||||
| x_tensor.SetData(input_data); | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||||
| auto output_data = new float[output_data_size]; | |||||
| std::vector<int> dim_dx({1, 28, 28, 3}); | |||||
| lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); | |||||
| dx_tensor.SetData(output_data); | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&dx_tensor}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(pooling_param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; | |||||
| lite::CompareOutput(output_data, output_path); | |||||
| // delete input_data; | |||||
| // delete[] output_data; | |||||
| delete pooling_param; | |||||
| MS_LOG(INFO) << "TestAvgPoolingGradFp32 passed"; | |||||
| } | |||||
| TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { | |||||
| // prepare stage | |||||
| auto pooling_param = new PoolingParameter(); | |||||
| InitPoolingParamFP32(pooling_param); | |||||
| pooling_param->output_channel_ = 3; | |||||
| pooling_param->avg_pooling_ = false; | |||||
| pooling_param->max_pooling_ = true; | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| uint64_t time_avg = 0; | |||||
| size_t output_data_size = | |||||
| pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; | |||||
| size_t input_size; | |||||
| std::string i_path = "./test_data/pooling/maxpoolgradfp32_1_i_1_28_28_3.bin"; | |||||
| auto ill_data = reinterpret_cast<int64_t *>(mindspore::lite::ReadFile(i_path.c_str(), &input_size)); | |||||
| auto i_data = new int[output_data_size]; | |||||
| for (uint32_t i = 0; i < output_data_size; i++) { | |||||
| i_data[i] = static_cast<int>(ill_data[i]); | |||||
| } | |||||
| std::string dy_path = "./test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin"; | |||||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &input_size)); | |||||
| auto output_data = new float[output_data_size]; | |||||
| // warm up loop | |||||
| for (int i = 0; i < 3; i++) { | |||||
| MaxPoolingGrad(dy_data, i_data, output_data, pooling_param); | |||||
| } | |||||
| int loop_count = 100; | |||||
| auto time_start = mindspore::lite::GetTimeUs(); | |||||
| for (int i = 0; i < loop_count; i++) { | |||||
| MaxPoolingGrad(dy_data, i_data, output_data, pooling_param); | |||||
| } | |||||
| auto time_end = mindspore::lite::GetTimeUs(); | |||||
| auto cost = time_end - time_start; | |||||
| time_avg = cost / loop_count; | |||||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin"; | |||||
| lite::CompareOutput(output_data, output_path); | |||||
| // delete input_data; | |||||
| delete pooling_param; | |||||
| delete[] output_data; | |||||
| MS_LOG(INFO) << "TestMaxPoolingGradFp32 passed"; | |||||
| } | |||||
| #if 0 | |||||
| TEST_F(TestPoolingGradFp32, MaxPoolingKernelGradFp32) { | |||||
| // prepare stage | |||||
| auto maxpool = new PoolingParameter(); | |||||
| InitPoolingParamFP32(maxpool); | |||||
| maxpool->avg_pooling_ = false; | |||||
| maxpool->max_pooling_ = true; | |||||
| maxpool->input_h_ = 30; | |||||
| maxpool->input_w_ = 30; | |||||
| maxpool->input_channel_ = 3; | |||||
| maxpool->output_batch_ = 1; | |||||
| maxpool->output_h_ = 10; | |||||
| maxpool->output_w_ = 10; | |||||
| maxpool->output_channel_ = 3; | |||||
| maxpool->stride_h_ = 3; | |||||
| maxpool->stride_w_ = 3; | |||||
| maxpool->pad_u_ = 0; | |||||
| maxpool->pad_d_ = 0; | |||||
| maxpool->pad_l_ = 0; | |||||
| maxpool->pad_r_ = 0; | |||||
| size_t input_size; | |||||
| size_t y_data_size = maxpool->output_batch_ * maxpool->output_channel_ * maxpool->output_h_ * maxpool->output_w_; | |||||
| auto x_data = reinterpret_cast<float *>( | |||||
| mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_2_x_1_30_30_3.bin", &input_size)); | |||||
| std::vector<int> dim_x({1, 30, 30, 3}); | |||||
| lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||||
| x_tensor.SetData(x_data); | |||||
| std::vector<lite::tensor::Tensor *> maxpool_inputs = {&x_tensor}; | |||||
| auto y_data = new float[y_data_size]; | |||||
| std::vector<int> dim_y({1, 10, 10, 3}); | |||||
| lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); | |||||
| y_tensor.SetData(y_data); | |||||
| auto ind_data = new int[y_data_size]; | |||||
| lite::tensor::Tensor ind_tensor(TypeId::kNumberTypeInt32, dim_y); | |||||
| ind_tensor.SetData(ind_data); | |||||
| std::vector<lite::tensor::Tensor *> maxpool_outputs = {&y_tensor, &ind_tensor}; | |||||
| kernel::KernelKey maxpool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Pooling}; | |||||
| auto maxpool_creator = lite::KernelRegistry::GetInstance()->GetCreator(maxpool_desc); | |||||
| auto maxpoolobj = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast<OpParameter *>(maxpool), | |||||
| NULL, maxpool_desc); | |||||
| maxpoolobj->Run(); | |||||
| printf("==================indices data=================\n"); | |||||
| for (int i = 0; i < 10; i++) { | |||||
| std::cout << ind_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| auto pooling_param = new PoolingParameter(); | |||||
| InitPoolingParamFP32(pooling_param); | |||||
| pooling_param->avg_pooling_ = false; | |||||
| pooling_param->max_pooling_ = true; | |||||
| pooling_param->input_h_ = 10; | |||||
| pooling_param->input_w_ = 10; | |||||
| pooling_param->input_channel_ = 3; | |||||
| pooling_param->output_batch_ = 1; | |||||
| pooling_param->output_h_ = 30; | |||||
| pooling_param->output_w_ = 30; | |||||
| pooling_param->output_channel_ = 3; | |||||
| // runtime part | |||||
| printf("Calculating runtime cost...\n"); | |||||
| // uint64_t time_avg = 0; | |||||
| size_t output_data_size = | |||||
| pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; | |||||
| auto dy_data = reinterpret_cast<float *>( | |||||
| mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_2_dy_1_10_10_3.bin", &input_size)); | |||||
| std::vector<int> dim_dy({1, 3, 10, 10}); | |||||
| lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||||
| dy_tensor.SetData(dy_data); | |||||
| #if 0 | |||||
| std::string i_path = "./test_data/pooling/maxpoolgradfp32_2_i_1_3_10_10.bin"; | |||||
| auto ill_data = reinterpret_cast<int64_t*>(mindspore::lite::ReadFile(i_path.c_str(), &input_size)); | |||||
| auto i_data = new int[output_data_size]; | |||||
| for (int i=0; i < output_data_size; i++) | |||||
| i_data[i] = static_cast<int>(ill_data[i]); | |||||
| std::vector<int> dim_ind({1, 3, 10, 10}); | |||||
| lite::tensor::Tensor ind_tensor(TypeId::kNumberTypeInt32, dim_ind); | |||||
| ind_tensor.SetData(i_data); | |||||
| #endif | |||||
| std::vector<lite::tensor::Tensor *> inputs = {&dy_tensor, &ind_tensor}; | |||||
| auto output_data = new float[output_data_size]; | |||||
| std::vector<int> dim_dx({1, 3, 30, 30}); | |||||
| lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); | |||||
| dx_tensor.SetData(output_data); | |||||
| std::vector<lite::tensor::Tensor *> outputs = {&dx_tensor}; | |||||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(pooling_param), NULL, desc); | |||||
| kernel_obj->Run(); | |||||
| printf("==================output data=================\n"); | |||||
| for (int i = 0; i < 20; i++) { | |||||
| std::cout << output_data[i] << " ,"; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| std::string output_path = "./test_data/pooling/maxpoolgradfp32_2_dx_1_30_30_3.bin"; | |||||
| lite::CompareOutput(output_data, output_path); | |||||
| // delete input_data; | |||||
| // delete[] output_data; | |||||
| delete pooling_param; | |||||
| MS_LOG(INFO) << "TestMaxPoolingKernelGradFp32 passed"; | |||||
| } | |||||
| #endif // if 0 before MaxPoolingKernelGradFp32 | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1 @@ | |||||
| "x>#Ď>K9�>pR >)J >¤á4>™K>¤ĹZ>ńß>ŢÝ>ńńL>‚µ=ËQ>Ń*^>MÖ>&¶>6>Sş>đ*�>ÉN>Ë-ý=Ó+L>ÜvK>+A}>wě^>$ďQ>´Ús>ł/W>ó×Ď=Mţ'>9[*>#%†<#�>CÖ>>ˇ‚>$ÁŻ=Gţj>ňě>Ă7*>´Ă2>łĆ6>•™>ń1p>ős#>Y)>çôk>9ď÷=´ŘŔ=lQ0>ű—w> | |||||
| @@ -0,0 +1 @@ | |||||
| èM®?Ç·ú¾Ôå? ¿H2¿|Ý7>0á?dyX?C�.¿\fT¾¼@?ªÍ³¿Öö¾Àg?Lwñ¾«˜Å¾E�¾Š9&¿7AÎ?†T?öXF¿4Å?â–?žÒ¹?(k?´0?¬¤?¤VH?-–¿Tz@½&À²»Ç"-ÀÞ1¿£wñ¾šÕË?·Fº¿¼�?�«ç¾¢D¼¶Â>’øY>ãÌ¿_pœ?ÄØ¾í]ç¼ ç’?À%R¿5§¿KsË=ó? | |||||
| @@ -0,0 +1 @@ | |||||
| Ь╨?╢6V?ЯUл?╗ШS?=ОM?;╘┤?3P≤?;╓?ИоE?мLn?u╣≥?▐!?╠ЗV?═╕?sаW?9_?яe?}≈H?hюд?▌ ?XБ=?ч ≥?%≥≤?ЮП╫?Y1╖?[s²?Д╤?фc║?ЖА?tЩ{?у┬?╣7и=╣DK?eаW?щЯц?шп?╣>╟?kcY?╓S?г├?┴?_fQ?u%╢?П-u?≈}?╜В╟?kС9?┤╒?=└?Э╠╧? | |||||
| @@ -0,0 +1 @@ | |||||
| v╚≈=qиы╫ьBs>Эл╬лй╬─вQ=@ъ<U█и?P2╪4[Ф? | |||||
| @@ -0,0 +1 @@ | |||||
| /�ηΏΦ�ƒΏ™5±>Β"Ώ†¥\Ώ��Ώ=`ΞΏώ;�?σε»Ώ�¥©? | |||||
| @@ -0,0 +1 @@ | |||||
| грНїЖ&Н5SaНГЏА?tЩ?Џ�@WЕ,М�СО2еН&�8?;VММ�В?цЉЁ?$х�?5ЮљМpNђМъF7О:�Ё?�5VН:Ю�?їm Нѕ,!@фљО`|в>VўЕНи�№М ?_бB?0дНвУ"?уЬяН�!>%џ=Ћ�,?��Мѓ>Йа�?�;?ъqЙНGh�?7џЩ<бНР�U>=дх?р-Нaыp?Ђеg?��й>Цr@XА> | |||||
| @@ -0,0 +1 @@ | |||||
| âË]>Eí>òJn>bK>Œ8£=Š<¾P—>&”“>óg>]±Ð<WBX=S¯t>ä;�>ã¤Q>I�>¢¦º=\ƒ>ºéS> Å€=äC*>šK=ën>IyŠ>„Š“>¾l‚>/=—>rp>Ÿ‹”>ü«>Ž( >ûÁ[>-ï–>{ëj=Ç4’>C¾”>«eŽ>D”>B“=ü£=x”>/m�>v¾j>P–>Ävï=PÊ•>“·=�3>vN=œ �>Ó‚—> | |||||
| @@ -0,0 +1 @@ | |||||
| wА>У⌠╒>XOО>?Ая>┐│h>И©е=╦║%?╠o:?9╦И>"qЖ=Ы╗7>┤ У>??╣ в>9▄?а╢{> t?D2\?╫J>n·╦>╟▐1>хчО>█OF?/г?7y?J0?мeT?A?F$╕>'╬÷>Abъ>╧▄#?"m@>к<?⌡8?О?∙?ЩяZ>≈$i>╜8?Ю*C?│)Л>rО3?А├▓>СX?9y>©╙©>^С2>Sт??w!'? | |||||
| @@ -0,0 +1 @@ | |||||
| wa?ő“"?XOo??áQ?��č>éżE>¸ˇĄ?±oş?9¸i?"qv>ů¨·>‡ u??�?µ W?9Ś“?Á´ű>št�?D2Ü?˝Ę>nž8?°Ź±>ČŢo?ŤOĆ?/Ç™?7y‚?J°?ÍeÔ?Aś?F$&?'ľ?Ab_?ąŚŁ?"mŔ>ËĽ?›¸?ď�?•›?ýŃÚ>—$é>¸?ŕ*Ă?�)l?rďł?á†?óXź?9ů>żŞ??^ó˛>SÔż?w!§? | |||||
| @@ -0,0 +1,2 @@ | |||||
| „ù@?^û*¿Su>?‹“¿(1?YÙ?O]Í>8�©>yåͽâ³h>Y·:¿¬×Ÿ<e | |||||
| ?@?C?È‘¾®�C?6GU¾páž>¶_=¿I³¿0`Í>0¾>ÝŽ9?Úÿ;?Gs*>e3>£”?¯Ê‘>œ»;?(ô,?õ&¿3*ˆ¾©Ü?ŠŸC?çC¿w<2?š�ð=ôKý>%HC¿ß¾ñ%8?òMâ>£¥œ¾ºñ~>'uû>Jß¿ÙI>^4Y¾uZ?ó¿ | |||||
| @@ -0,0 +1,4 @@ | |||||
| ‹F ½.¸¿NÓ2¾œ³Õ?󻾩`?°�ͽÕåØ¿S¿Ä¯¿”2x¿¥R=}è%À–Tá?9¢>¾Ró?E„ÿ†Ö?ÇÕÜ>©´?®Á@< ƒ¿ˆ�*¿Fü‰¿sÑ?¹Ýh¿ˆêý>ðÛ¾i) ¿�W>+ Ò;ôÎä=y«@\ô?ð¿V=~?ú)о‡Ï¬?HF}?åžÕ¿˜Õ«?Fê“¿±E | |||||
| ?ŽG¿¾»Ã¿#μ¿>P¨D>‘¼È>J‹Å?gNð¾Y, <¹Öˆ¿�–u?Y_À"¶�¾ ñ4?À¢f¾ôœ¿x‹¾ YÀ7ü;¡̾…ÑÚ?Ö)™?£°©?€ì…?Ö]@-Ç/¼z²b?Áäï¿Y¸¾ñ e?MÖý¾ /6?¦"¿ë‡?œt«?ŽØT?; -¾½Ø1?,6?¿•¨¿.ª>nÉÞ>8«D?�Ǿ | |||||
| ãF¿Ö+j?~B? | |||||
| ê¿»¤P¾æ›Ç?Šœ?t ¾½ªek?›ûI?W³J>®ó&?æÑ ?;ñ;ƒéš¿Êj俾=¾í±¿È¾sg»?¡ÝÀ¿[kÄ?‚âr?Ý–c>.þ¾äÏl¼žjy?¥DÊ>S«î¼“¬Â¿º‡?lìÒ?rS¾Ùq´¾Åä?#m(@±_?>±¿l ‚¿Ž%6À˜¢?’<j¾¹z>h%¿Ðké¾>Þ=4 ù¿¨Œö?ªÅ‹¼o´J¾—û¾ �¿¥§s>•¾fW—?8c;?‚ä�?Æk»¾Š:º?bQ1¿ƒ¤Y>ypþ>½nW=úz¿|S:?P‚’?çrð?€KŠ?嘿k†µ¿wØâ>‘„-¼æ«>³Ví?Å~> | |||||
| @@ -0,0 +1 @@ | |||||
| .Àø@8�|À¸A-=,Ëú¿fAQÀ>2@du¤Ài}?ÿÃÀtÏ4@œˆš@2ªÊÀüzN@Ò¥¿Ô£x@Þ&ó½(Ó‚½‚¸e@�g�¼ | |||||
| @@ -0,0 +1 @@ | |||||
| Ñ&J¿vÿ�ÀB×AoLÝ?ˆÍƒÀI?ïÁ…óÌ¿ÜÐA6µÔ½ÏS¹¾?Á…—¸Áãþ�Àd>‘À5²®?±Çi¾’’Ó¿`ó@Úý ¾¿uº@G×Ý@`M>Aü¹Àv>B)¦ä>c”¾ÉÁ@$Ì/AwèŒA‡àÁŒˆË¿ß^ ¿ñ0ÁkܾÈ�£AfûÁèrñ>0xË¿€c¬ÀR†?ÂvuÁÊ=ˆ¿`,>pÅ”?aKÇÀªû@ðóîÁÓÀy?¼×¾ÿdb¿3ž?¹@Ú¤À¬�’¿eK?üÑÀЧ9¿)ÖÓ?u£u@ýþ?¡"á=ûP’>bæë>vá°@þNl@ÅÃ�ÀUî>otÀ ’ð?Æ*�@y ;³õ @´ÖAv‘©¿_„¿× ä¿¶�î¿\�¿™qØ?˜w@–0ÀÝ»¶¾©j€?òï÷ÀAq;Àb½oÀ¡Ž,¿Ò3@ûI`?3sfÀlø‡À€Ê@@IÚ?˜£§¿Ž·? ¡AC_>=ƒL�¿†Äó@„œÏ? .–¿ïà¢@xy`? ´á¿¿�ŒA3 3¿³Ë‘?Ãn�?.À¾ûžÛ=é\B@/õÀõAç>B_‚¿»�©¾Kø¿–Ç<wÀWcz=m9= | |||||
| @@ -0,0 +1 @@ | |||||
| œwÓ¿Öè�?8ðÖ>GF¼?-¢^À�.;?º£÷¿V]«?K“@zè:À} QÀ’Põ¾qÂH¿Í Å¿‚‘ã?¯Õ,@~Ð?LP§>7>ÄP«¿¥¼q@Ÿå¿P«š>ƒ¬·@ | |||||
| @@ -0,0 +1 @@ | |||||
| L&†À”œ¾h>)AÃ[7?Á.šÀO2Ê@ | |||||
| @@ -0,0 +1 @@ | |||||
| œwÓ¿Öè�?8ðÖ>GF¼?-¢^À�.;?º£÷¿V]«?K“@zè:À} QÀ’Põ¾qÂH¿Í Å¿‚‘ã?¯Õ,@~Ð?LP§>7>ÄP«¿¥¼q@Ÿå¿P«š>ƒ¬·@ | |||||
| @@ -0,0 +1 @@ | |||||
| ˆÁн.%Ž?ý•Ž¿.s?d¤“¿-f×=<V¾nß”?`Dz?�tŸ¿´"†¿?ýÿ¾e|»>�~�¾Þ¾šIÈ?ˆ ³??³Æ½ÌY?ü¼3¾B‘?ޝX½f¦<?ä'@ | |||||
| @@ -0,0 +1 @@ | |||||
| kBt?[[T> | |||||
| @@ -0,0 +1 @@ | |||||
| ‹þÀ2øòÀ ©)A%»„ÀdÈ‘À°º À | |||||
| @@ -0,0 +1 @@ | |||||
| ‹þÀ2øòÀ ©)A%»„ÀdÈ‘À°º À | |||||