| @@ -235,3 +235,4 @@ if (NOT WIN32) | |||
| endif () | |||
| include(${TOP_DIR}/cmake/package_lite.cmake) | |||
| @@ -17,37 +17,28 @@ | |||
| #define MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include <unordered_map> | |||
| #include "src/lite_session.h" | |||
| #include "include/lite_session.h" | |||
| #include "include/train_model.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| struct TrainModel; | |||
| } | |||
| namespace session { | |||
| class TrainSession : public lite::LiteSession { | |||
| public: | |||
| TrainSession(); | |||
| ~TrainSession(); | |||
| int RunGraph(const session::KernelCallBack &before = nullptr, | |||
| const session::KernelCallBack &after = nullptr) override; | |||
| int CompileGraph(lite::Model *model) override; | |||
| virtual void* ExportToBuf(char* buf, size_t* len) const; | |||
| class TrainSession : public session::LiteSession { | |||
| public: | |||
| virtual ~TrainSession() = default; | |||
| static TrainSession *CreateSession(lite::Context *context); | |||
| virtual void Train(); | |||
| virtual int CompileTrainGraph(lite::TrainModel *model) = 0; | |||
| virtual void *ExportToBuf(char *buf, size_t *len) const = 0; | |||
| virtual void Train() = 0; | |||
| bool IsTrain() { return train_mode_; } | |||
| virtual void Eval(); | |||
| virtual void Eval() = 0; | |||
| bool IsEval() { return !train_mode_; } | |||
| protected: | |||
| virtual void ReplaceOps(); | |||
| bool train_mode_ = false; | |||
| lite::TrainModel *model_ = nullptr; | |||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_map_; | |||
| std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_; | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
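| A minimal usage sketch of the factory-based API above (editor's illustration; the import helper, buffer names, and input binding are assumptions, not part of this patch): | |||
| lite::Context context; | |||
| session::TrainSession *session = session::TrainSession::CreateSession(&context); | |||
| lite::TrainModel *model = lite::TrainModel::Import(model_buf, size);  // assumed loader | |||
| session->CompileTrainGraph(model); | |||
| session->Train();     // switch kernels to training mode | |||
| session->RunGraph();  // one step; inputs assumed already bound | |||
| session->Eval();      // back to inference mode | |||
| size_t len = 0; | |||
| void *buf = session->ExportToBuf(nullptr, &len);  // serialize trained weights | |||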
| @@ -22,6 +22,7 @@ | |||
| typedef struct BatchNormParameter { | |||
| OpParameter op_parameter_; | |||
| float epsilon_; | |||
| float momentum_; | |||
| int unit_; | |||
| int units_; | |||
| int channel_; | |||
| @@ -54,22 +54,22 @@ void FusedBatchNormFp32(const void *input, const void *scale, const void *offset | |||
| } | |||
| } | |||
| void FusedBatchNormFp32MeanVar(const float *input, float momentum, float *run_mean, float *run_var, | |||
| BatchNormParameter *param, float *save_mean, float *save_inv_var) { | |||
| void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, BatchNormParameter *param, | |||
| float *save_mean, float *save_var) { | |||
| float N = (float)param->unit_; | |||
| for (int i = 0; i < param->unit_; i++) { | |||
| for (int f = 0; f < param->channel_; f++) { | |||
| int idx = i * param->channel_ + f; | |||
| run_mean[f] += input[idx]; | |||
| run_var[f] += input[idx] * input[idx]; | |||
| for (int c = 0; c < param->channel_; c++) { | |||
| int idx = i * param->channel_ + c; | |||
| run_mean[c] += input[idx]; | |||
| run_var[c] += input[idx] * input[idx]; | |||
| } | |||
| } | |||
| const float VN = (N > 1.0f) ? (N - 1.0f) : 1.0f; | |||
| for (int f = 0; f < param->channel_; f++) { | |||
| run_mean[f] = run_mean[f] / N; | |||
| run_var[f] = run_var[f] / VN - run_mean[f] * run_mean[f]; | |||
| save_mean[f] = momentum * save_mean[f] + (1 - momentum) * run_mean[f]; | |||
| const float inv_var = 1.f / sqrt(run_var[f] + param->epsilon_); | |||
| save_inv_var[f] = momentum * save_inv_var[f] + (1 - momentum) * inv_var; | |||
| for (int c = 0; c < param->channel_; c++) { | |||
| run_mean[c] = run_mean[c] / N; | |||
| run_var[c] = run_var[c] / VN - run_mean[c] * run_mean[c]; | |||
| save_mean[c] = param->momentum_ * save_mean[c] + (1 - param->momentum_) * run_mean[c]; | |||
| const float var = run_var[c]; | |||
| save_var[c] = param->momentum_ * save_var[c] + (1 - param->momentum_) * var; | |||
| } | |||
| } | |||
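| Per channel, the loops above reduce to the usual running-statistics rule: mean = sum(x)/N, var = sum(x*x)/(N-1) - mean*mean, followed by an exponential moving average of the saved stats. A self-contained sketch with invented numbers: | |||
| // EMA step performed per channel above (values below are illustrative only) | |||
| static float ema_update(float saved, float batch_stat, float momentum) { | |||
|   return momentum * saved + (1.0f - momentum) * batch_stat; | |||
| } | |||
| // e.g. ema_update(0.5f, 0.7f, 0.9f) == 0.9f*0.5f + 0.1f*0.7f == 0.52f | |||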
| @@ -28,8 +28,8 @@ void BatchNormFp32(const void *input, const void *mean, const void *variance, Ba | |||
| void FusedBatchNormFp32(const void *input, const void *scale, const void *offset, const void *mean, | |||
| const void *variance, BatchNormParameter *param, int task_id, void *output); | |||
| void FusedBatchNormFp32MeanVar(const float *input, float momentum, float *run_mean, float *run_var, | |||
| BatchNormParameter *param, float *save_mean, float *save_var); | |||
| void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, BatchNormParameter *param, | |||
| float *save_mean, float *save_var); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_ | |||
| #include "nnacl/op_base.h" | |||
| typedef struct ApplyMomentumParameter { | |||
| OpParameter op_parameter_; | |||
| bool use_locking_; | |||
| bool use_nesterov_; | |||
| float grad_scale_; | |||
| } ApplyMomentumParameter; | |||
| typedef struct SgdParameter { | |||
| OpParameter op_parameter_; | |||
| float dampening_; | |||
| bool use_nesterov_; | |||
| float weight_decay_; | |||
| } SgdParameter; | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_ | |||
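| For reference, a sketch of the update rules these structs configure: the ApplyMomentum form matches the kernel later in this patch, while the Sgd form is an assumption (PyTorch-style); lr and moment arrive as input tensors, not as struct fields: | |||
| #include <stdbool.h> | |||
| #include <stddef.h> | |||
| static void apply_momentum_step(float *w, float *accum, const float *grad, | |||
|                                 float lr, float moment, bool nesterov, size_t n) { | |||
|   for (size_t i = 0; i < n; ++i) { | |||
|     accum[i] = accum[i] * moment + grad[i]; | |||
|     w[i] -= (nesterov ? accum[i] * moment + grad[i] : accum[i]) * lr; | |||
|   } | |||
| } | |||
| // Assumed Sgd form: g += weight_decay_ * w; accum = accum * moment + (1 - dampening_) * g; | |||
| // w -= lr * (use_nesterov_ ? g + moment * accum : accum) | |||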
| @@ -20,10 +20,8 @@ | |||
| static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); } | |||
| void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param) { | |||
| const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_; | |||
| // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; | |||
| const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_; | |||
| // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; | |||
| const int pad_left = conv_param->pad_l_; | |||
| const int pad_up = conv_param->pad_u_; | |||
| const int stride_h = conv_param->stride_h_; | |||
| const int stride_w = conv_param->stride_w_; | |||
| @@ -39,10 +37,11 @@ void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param | |||
| const int output_h = conv_param->output_h_; | |||
| const int output_w = conv_param->output_w_; | |||
| const int channels = conv_param->input_channel_ / conv_param->group_; | |||
| const int tot_channels = conv_param->input_channel_; | |||
| int /*channel,*/ kernel_row, kernel_col, output_rows, output_col; | |||
| int kernel_row, kernel_col, output_rows, output_col; | |||
| int row_stride_offset = 0; | |||
| @@ -71,11 +70,9 @@ void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param | |||
| } | |||
| // output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w) | |||
| void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param) { | |||
| const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_; | |||
| // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; | |||
| const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_; | |||
| // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; | |||
| void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param, bool transpose) { | |||
| const int pad_left = conv_param->pad_l_; | |||
| const int pad_up = conv_param->pad_u_; | |||
| const int stride_h = conv_param->stride_h_; | |||
| const int stride_w = conv_param->stride_w_; | |||
| @@ -86,38 +83,67 @@ void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param | |||
| const int kernel_h = conv_param->kernel_h_; | |||
| const int kernel_w = conv_param->kernel_w_; | |||
| const int in_height = conv_param->input_h_; | |||
| const int in_width = conv_param->input_w_; | |||
| const int in_height = (transpose) ? conv_param->output_h_ : conv_param->input_h_; | |||
| const int in_width = (transpose) ? conv_param->output_w_ : conv_param->input_w_; | |||
| const int output_h = conv_param->output_h_; | |||
| const int output_w = conv_param->output_w_; | |||
| const int channels = conv_param->input_channel_ / conv_param->group_; | |||
| const int tot_channels = conv_param->input_channel_; | |||
| const int output_h = (transpose) ? conv_param->input_h_ : conv_param->output_h_; | |||
| const int output_w = (transpose) ? conv_param->input_w_ : conv_param->output_w_; | |||
| const int tot_channels = (transpose) ? conv_param->output_channel_ : conv_param->input_channel_; | |||
| const int channels = tot_channels / conv_param->group_; | |||
| int channel, kernel_row, kernel_col, output_rows, output_col; | |||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||
| for (channel = 0; channel < channels; channel++) { | |||
| int input_row = -pad_up + kernel_row * dilation_h; | |||
| for (output_rows = output_h; output_rows; output_rows--) { | |||
| if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { | |||
| for (output_col = output_w; output_col; output_col--) { | |||
| *(data_row++) = 0; | |||
| if (transpose) { | |||
| for (channel = 0; channel < channels; channel++) { | |||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||
| int input_row = -pad_up + kernel_row * dilation_h; | |||
| for (output_rows = output_h; output_rows; output_rows--) { | |||
| if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { | |||
| for (output_col = output_w; output_col; output_col--) { | |||
| *(data_row++) = 0; | |||
| } | |||
| } else { | |||
| int input_col = -pad_left + kernel_col * dilation_w; | |||
| for (output_col = output_w; output_col; output_col--) { | |||
| if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||
| const int offset = (input_row * in_width + input_col) * tot_channels + channel; | |||
| *(data_row++) = in_data[offset]; | |||
| } else { | |||
| *(data_row++) = 0; | |||
| } | |||
| input_col += stride_w; | |||
| } | |||
| } | |||
| } else { | |||
| int input_col = -pad_left + kernel_col * dilation_w; | |||
| for (output_col = output_w; output_col; output_col--) { | |||
| if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||
| const int offset = (input_row * in_width + input_col) * tot_channels + channel; | |||
| *(data_row++) = in_data[offset]; | |||
| } else { | |||
| input_row += stride_h; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { | |||
| for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { | |||
| for (channel = 0; channel < channels; channel++) { | |||
| int input_row = -pad_up + kernel_row * dilation_h; | |||
| for (output_rows = output_h; output_rows; output_rows--) { | |||
| if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { | |||
| for (output_col = output_w; output_col; output_col--) { | |||
| *(data_row++) = 0; | |||
| } | |||
| input_col += stride_w; | |||
| } else { | |||
| int input_col = -pad_left + kernel_col * dilation_w; | |||
| for (output_col = output_w; output_col; output_col--) { | |||
| if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { | |||
| const int offset = (input_row * in_width + input_col) * tot_channels + channel; | |||
| *(data_row++) = in_data[offset]; | |||
| } else { | |||
| *(data_row++) = 0; | |||
| } | |||
| input_col += stride_w; | |||
| } | |||
| } | |||
| input_row += stride_h; | |||
| } | |||
| input_row += stride_h; | |||
| } | |||
| } | |||
| } | |||
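| Reading of the new flag: with transpose set, the input/output roles swap and the channel loop moves outermost, so the same routine can presumably serve the DeConv2DGradFilter path added in the schema below. The two call shapes, assuming conv is an initialized ConvParameter: | |||
| im2row_hwc(x, cols, &conv, false);  // forward: patches taken from input_h x input_w | |||
| im2row_hwc(dy, cols, &conv, true);  // grad path: patches taken from output_h x output_w | |||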
| @@ -125,10 +151,8 @@ void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param | |||
| } | |||
| void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param) { | |||
| const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_; | |||
| // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; | |||
| const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_; | |||
| // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; | |||
| const int pad_left = conv_param->pad_l_; | |||
| const int pad_up = conv_param->pad_u_; | |||
| const int stride_h = conv_param->stride_h_; | |||
| const int stride_w = conv_param->stride_w_; | |||
| @@ -23,7 +23,7 @@ | |||
| extern "C" { | |||
| #endif | |||
| void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param); | |||
| void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param); | |||
| void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param, bool transpose); | |||
| void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param); | |||
| #ifdef __cplusplus | |||
| } | |||
| @@ -17,7 +17,7 @@ | |||
| #include <float.h> | |||
| #include "nnacl/fp32_grad/pooling_grad.h" | |||
| void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) { | |||
| void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { | |||
| int stride_w = pooling_param->stride_w_; | |||
| int stride_h = pooling_param->stride_h_; | |||
| int pad_w = pooling_param->pad_l_; | |||
| @@ -41,7 +41,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter | |||
| for (uint16_t yh = 0; yh < output_h; yh++) { | |||
| for (uint16_t yw = 0; yw < output_w; yw++) { | |||
| for (uint16_t ic = 0; ic < channel; ic++) { | |||
| int idx = (yw + yh * output_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw; | |||
| int idx = (yw + yh * output_w) * channel + ic; | |||
| float delta = inPtr[idx] / kk; | |||
| for (int32_t kh = 0; kh < win_h; kh++) { | |||
| int xh = yh * stride_h + kh - pad_h; | |||
| @@ -63,7 +63,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter | |||
| } | |||
| void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr, | |||
| PoolingParameter *pooling_param) { | |||
| PoolingParameter *pooling_param, int task_id) { | |||
| int stride_w = pooling_param->stride_w_; | |||
| int stride_h = pooling_param->stride_h_; | |||
| int pad_w = pooling_param->pad_l_; | |||
| @@ -22,9 +22,9 @@ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param); | |||
| void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); | |||
| void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr, | |||
| PoolingParameter *pooling_param); | |||
| PoolingParameter *pooling_param, int task_id); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
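| The new task_id parameter matches the ParallelLaunch worker signature; a hypothetical dispatch sketch (the carrier struct and wrapper name are assumptions, mirroring the *Run adapters added elsewhere in this patch): | |||
| typedef struct { const float *in; float *out; PoolingParameter *param; } PoolingGradArgs; | |||
| int AvgPoolingGradRun(void *cdata, int task_id) { | |||
|   PoolingGradArgs *args = (PoolingGradArgs *)cdata; | |||
|   AvgPoolingGrad(args->in, args->out, args->param, task_id); | |||
|   return 0; | |||
| } | |||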
| @@ -207,6 +207,7 @@ union PrimitiveType { | |||
| LshProjection, | |||
| HashtableLookup, | |||
| SkipGram, | |||
| DeConv2DGradFilter, | |||
| CustomPredict, | |||
| CustomNormalize, | |||
| CustomExtractFeatures, | |||
| @@ -215,6 +216,7 @@ union PrimitiveType { | |||
| Rfft, | |||
| FftReal, | |||
| FftImag, | |||
| Sgd, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -407,6 +407,27 @@ table DeConv2D { | |||
| hasBias: bool = false; | |||
| activationType: ActivationType = 0; | |||
| } | |||
| table DeConv2DGradFilter { | |||
| format: Format = 0; | |||
| group: int; | |||
| channelIn: int; | |||
| channelOut: int; | |||
| kernelW: int; | |||
| kernelH: int; | |||
| strideW: int; | |||
| strideH: int; | |||
| padMode: PadMode; | |||
| padUp: int; | |||
| padDown: int; | |||
| padLeft: int; | |||
| padRight: int; | |||
| dilateW: int; | |||
| dilateH: int; | |||
| hasBias: bool = false; | |||
| activationType: ActivationType = 0; | |||
| } | |||
| table BNGrad { | |||
| eps : float; | |||
| momentum: float; | |||
| @@ -884,6 +905,11 @@ table ApplyMomentum { | |||
| useNesterov: bool; | |||
| } | |||
| table Sgd { | |||
| weightDecay: float; | |||
| dampening: float; | |||
| useNesterov: bool; | |||
| } | |||
| table Where{ | |||
| condition: [bool]; | |||
| @@ -45,7 +45,7 @@ int CompareRelativeOutput(float *output_data, std::string file_path) { | |||
| return 1; | |||
| } | |||
| size_t output_num = output_size / sizeof(float); | |||
| int error = CompareOutputRelativeData(output_data, ground_truth, output_num); | |||
| float error = CompareOutputRelativeData(output_data, ground_truth, output_num); | |||
| delete[] ground_truth; | |||
| if (error > 1e-4) { | |||
| return 1; | |||
| @@ -18,6 +18,22 @@ | |||
| #include <algorithm> | |||
| namespace mindspore::kernel { | |||
| void *LiteKernel::workspace_ = nullptr; | |||
| void LiteKernel::AllocWorkspace(size_t size) { | |||
| if (size == 0) return; | |||
| workspace_ = malloc(size); | |||
| if (workspace_ == nullptr) { | |||
| MS_LOG(ERROR) << "fail to alloc " << size; | |||
| } | |||
| } | |||
| void LiteKernel::FreeWorkspace() { | |||
| free(workspace_); | |||
| workspace_ = nullptr; | |||
| } | |||
| void LiteKernel::InitOutTensorRefCount() { | |||
| for (auto *tensor : this->out_tensors_) { | |||
| tensor->SetRefCount(this->out_kernels_.size()); | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_LITE_SRC_LITE_KERNEL_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/common/utils.h" | |||
| #ifdef ENABLE_ARM | |||
| @@ -145,6 +146,11 @@ class LiteKernel { | |||
| void set_desc(const KernelKey kernel_key) { desc_ = kernel_key; } | |||
| const mindspore::lite::PrimitiveC *GetPrimitive() const { return primitive_; } | |||
| void SetWorkspaceSize(size_t value) { workspace_size_ = value; } | |||
| size_t GetWorkspaceSize() { return workspace_size_; } | |||
| static void AllocWorkspace(size_t size); | |||
| static void FreeWorkspace(); | |||
| void *GetWorkspace() { return workspace_; } | |||
| protected: | |||
| bool InferShapeDone() { return !(primitive_ != nullptr && !primitive_->GetInferFlag()); } | |||
| @@ -161,6 +167,8 @@ class LiteKernel { | |||
| std::vector<LiteKernel *> out_kernels_; | |||
| bool train_mode_ = false; | |||
| bool is_model_output_ = false; | |||
| size_t workspace_size_ = 0; | |||
| static void *workspace_; | |||
| }; | |||
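| The workspace is one static scratch buffer shared by all kernels, so the intended lifecycle is: each kernel reports its need in Init() via SetWorkspaceSize(), the session allocates a single buffer of the maximum size before execution, and kernels read it with GetWorkspace(). A session-side sketch (assumed driver code, not in this patch): | |||
| size_t max_ws = 0; | |||
| for (auto *kernel : kernels) { | |||
|   if (kernel->GetWorkspaceSize() > max_ws) max_ws = kernel->GetWorkspaceSize(); | |||
| } | |||
| kernel::LiteKernel::AllocWorkspace(max_ws);  // one shared scratch buffer | |||
| // ... run the graph; kernels call GetWorkspace() inside Execute() ... | |||
| kernel::LiteKernel::FreeWorkspace(); | |||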
| class SubGraphKernel : public LiteKernel { | |||
| @@ -17,6 +17,10 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| float ApplyMomentum::GetGradientScale() const { return this->primitive_->value.AsApplyMomentum()->gradientScale; } | |||
| bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value.AsApplyMomentum()->useLocking; } | |||
| bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value.AsApplyMomentum()->useNesterov; } | |||
| int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| @@ -36,6 +40,10 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->gradientScale = GetValue<float>(prim.GetAttr("gradient_scale")); | |||
| attr->useLocking = GetValue<bool>(prim.GetAttr("use_locking")); | |||
| attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov")); | |||
| this->primitive_->value.value = attr.release(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| @@ -45,6 +53,10 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| float ApplyMomentum::GetGradientScale() const { return this->primitive_->value_as_ApplyMomentum()->gradientScale(); } | |||
| bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value_as_ApplyMomentum()->useLocking(); } | |||
| bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value_as_ApplyMomentum()->useNesterov(); } | |||
| int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| @@ -53,7 +65,7 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||
| MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateApplyMomentum(*fbb); | |||
| auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useLocking(), attr->useNesterov()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| @@ -62,7 +74,7 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||
| int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||
| if (5 != inputs.size()) { | |||
| MS_LOG(ERROR) << "ApplyMomentum should have at 5 input tensors"; | |||
| MS_LOG(ERROR) << "ApplyMomentum should have at least 5 input tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -76,6 +88,7 @@ int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<li | |||
| MS_ASSERT(out != nullptr); | |||
| out->set_data_type(inputs[0]->data_type()); | |||
| out->SetFormat(inputs[0]->GetFormat()); | |||
| out->set_shape({1}); | |||
| } | |||
| return RET_OK; | |||
| @@ -39,6 +39,9 @@ class ApplyMomentum : public PrimitiveC { | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| float GetGradientScale() const; | |||
| bool GetUseLocking() const; | |||
| bool GetUseNesterov() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -89,6 +89,7 @@ int BiasGrad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> out | |||
| auto *out = outputs.front(); | |||
| MS_ASSERT(in0 != nullptr); | |||
| MS_ASSERT(out != nullptr); | |||
| auto inshape = in0->shape(); | |||
| int ndim = inshape.size(); | |||
| for (int i = 0; i < ndim - 1; i++) { | |||
| @@ -75,7 +75,7 @@ float BNGrad::GetEps() const { return this->primitive_->value_as_BNGrad()->eps() | |||
| float BNGrad::GetMomentum() const { return this->primitive_->value_as_BNGrad()->momentum(); } | |||
| #endif | |||
| int BNGrad::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||
| if (5 != inputs.size()) { | |||
| if (6 != inputs.size()) { | |||
| MS_LOG(ERROR) << "BNGrad should have five inputs"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -85,6 +85,7 @@ int BNGrad::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Ten | |||
| } | |||
| auto in = inputs[1]; | |||
| auto scale = inputs[2]; | |||
| outputs[0]->set_shape(in->shape()); | |||
| outputs[1]->set_shape(scale->shape()); | |||
| outputs[2]->set_shape(scale->shape()); | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_BN_GRAD_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_BN_GRAD_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| @@ -44,4 +44,4 @@ class BNGrad : public PrimitiveC { | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||
| #endif // MINDSPORE_LITE_SRC_OPS_BN_GRAD_H_ | |||
| @@ -73,5 +73,20 @@ float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_Fu | |||
| int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); } | |||
| #endif | |||
| int FusedBatchNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| if (outputs_.size() <= i) break; | |||
| outputs_.at(i)->set_shape(inputs_.at(i)->shape()); | |||
| outputs_.at(i)->set_data_type(inputs_.at(i)->data_type()); | |||
| outputs_.at(i)->SetFormat(inputs_.at(i)->GetFormat()); | |||
| } | |||
| if (outputs_.size() > 5) { | |||
| outputs_.at(5)->set_data_type(inputs_.at(0)->data_type()); | |||
| outputs_.at(5)->SetFormat(inputs_.at(0)->GetFormat()); | |||
| outputs_.at(5)->set_shape({1}); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -39,6 +39,7 @@ class FusedBatchNorm : public PrimitiveC { | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| float GetEpsilon() const; | |||
| float GetMomentum() const; | |||
| int GetSpatial() const; | |||
| @@ -145,7 +145,15 @@ int PoolingGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf | |||
| #endif | |||
| int PoolingGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| MS_ASSERT(this->primitive != nullptr); | |||
| if (3 != inputs_.size()) { | |||
| MS_LOG(ERROR) << "Pooling Grad Filter should have 3 inputs"; | |||
| return RET_ERROR; | |||
| } | |||
| if (1 != outputs_.size()) { | |||
| MS_LOG(ERROR) << "Pooling Grad Filter should have one output"; | |||
| return RET_ERROR; | |||
| } | |||
| auto input = inputs_.at(0); | |||
| MS_ASSERT(input != nullptr); | |||
| int input_h = input->shape().at(1); | |||
| @@ -151,6 +151,7 @@ | |||
| #include "src/ops/depend.h" | |||
| #include "src/ops/flatten_grad.h" | |||
| #include "src/ops/log_grad.h" | |||
| #include "src/ops/sgd.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -384,7 +385,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Dequant>(prim, inputs, quantType); | |||
| } else if (op_type == "Flatten") { | |||
| return NewPrimitiveC<Flatten>(prim, inputs, quantType); | |||
| } else if (op_type == "FusedBatchNorm") { | |||
| } else if ((op_type == "FusedBatchNorm") || (op_type == "FusedBatchNormEx")) { | |||
| return NewPrimitiveC<FusedBatchNorm>(prim, inputs, quantType); | |||
| } else if (op_type == "make_tuple") { | |||
| return NewPrimitiveC<MakeTuple>(prim, inputs, quantType); | |||
| @@ -452,7 +453,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType); | |||
| } else if (op_type == "Conv2DBackpropInput") { | |||
| return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType); | |||
| } else if (op_type == "BatchNormGrad") { | |||
| } else if ((op_type == "BatchNormGrad") || (op_type == "FusedBatchNormGradEx")) { | |||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "FlattenGrad") { | |||
| return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType); | |||
| @@ -460,6 +461,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "Tile") { | |||
| return NewPrimitiveC<Tile>(prim, inputs, quantType); | |||
| } else if (op_type == "PowerGrad") { | |||
| return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "SGD") { | |||
| return NewPrimitiveC<Sgd>(prim, inputs, quantType); | |||
| #else | |||
| } else if (op_type == "Conv2DBackpropInput") { | |||
| return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | |||
| @@ -731,6 +736,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new NegGrad(primitive); | |||
| case schema::PrimitiveType_LogGrad: | |||
| return new LogGrad(primitive); | |||
| case schema::PrimitiveType_Sgd: | |||
| return new Sgd(primitive); | |||
| #endif | |||
| default: | |||
| @@ -995,6 +1002,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { | |||
| return NewPrimitiveC<NegGrad>(primitive); | |||
| case schema::PrimitiveType_LogGrad: | |||
| return NewPrimitiveC<LogGrad>(primitive); | |||
| case schema::PrimitiveType_Sgd: | |||
| return NewPrimitiveC<Sgd>(primitive); | |||
| #endif | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | |||
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/sgd.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| float Sgd::GetWeightDecay() const { return this->primitive_->value.AsSgd()->weightDecay; } | |||
| float Sgd::GetDampening() const { return this->primitive_->value.AsSgd()->dampening; } | |||
| bool Sgd::GetUseNesterov() const { return this->primitive_->value.AsSgd()->useNesterov; } | |||
| int Sgd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Sgd; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Sgd) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = std::make_unique<schema::SgdT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->weightDecay = GetValue<float>(prim.GetAttr("weight_decay")); | |||
| attr->dampening = GetValue<float>(prim.GetAttr("dampening")); | |||
| attr->useNesterov = GetValue<bool>(prim.GetAttr("nesterov")); | |||
| this->primitive_->value.value = attr.release(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| float Sgd::GetWeightDecay() const { return this->primitive_->value_as_Sgd()->weightDecay(); } | |||
| float Sgd::GetDampening() const { return this->primitive_->value_as_Sgd()->dampening(); } | |||
| bool Sgd::GetUseNesterov() const { return this->primitive_->value_as_Sgd()->useNesterov(); } | |||
| int Sgd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Sgd(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Sgd return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateSgd(*fbb, attr->weightDecay(), attr->dampening(), attr->useNesterov()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Sgd, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| #endif | |||
| int Sgd::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||
| if (6 != inputs.size()) { | |||
| MS_LOG(ERROR) << "Sgd should have at least 6 input tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[3]->ElementsNum() || | |||
| inputs[2]->ElementsNum() != 1 || inputs[4]->ElementsNum() != 1) { | |||
| MS_LOG(ERROR) << "error input data size!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (!outputs.empty()) { | |||
| auto *out = outputs.front(); | |||
| MS_ASSERT(out != nullptr); | |||
| out->set_data_type(inputs[0]->data_type()); | |||
| out->SetFormat(inputs[0]->GetFormat()); | |||
| out->set_shape({1}); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
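| The size checks above imply a six-tensor input layout; the role names below are inferred from the ApplyMomentum kernel in this patch and are an assumption: | |||
| // inputs[0] weight        (n elements) | |||
| // inputs[1] accumulate    (n elements, checked against inputs[0]) | |||
| // inputs[2] learning rate (scalar, ElementsNum() == 1) | |||
| // inputs[3] gradient      (n elements, checked against inputs[0]) | |||
| // inputs[4] moment        (scalar, ElementsNum() == 1) | |||
| // inputs[5] extra state   (not size-checked here) | |||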
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_SGD_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_SGD_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Sgd : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Sgd, PrimitiveC); | |||
| Sgd() = default; | |||
| explicit Sgd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Sgd() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| float GetWeightDecay() const; | |||
| float GetDampening() const; | |||
| bool GetUseNesterov() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_OPS_SGD_H_ | |||
| @@ -633,6 +633,7 @@ OpParameter *PopulateFusedBatchNorm(const mindspore::lite::PrimitiveC *primitive | |||
| auto param = | |||
| reinterpret_cast<mindspore::lite::FusedBatchNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| batch_norm_param->epsilon_ = param->GetEpsilon(); | |||
| batch_norm_param->momentum_ = param->GetMomentum(); | |||
| batch_norm_param->fused_ = true; | |||
| return reinterpret_cast<OpParameter *>(batch_norm_param); | |||
| } | |||
| @@ -37,6 +37,14 @@ void FusedBatchnormCPUKernel::FreeScaleAndOffset() { | |||
| free(offset_); | |||
| offset_ = nullptr; | |||
| } | |||
| if (save_mean_ != nullptr) { | |||
| free(save_mean_); | |||
| save_mean_ = nullptr; | |||
| } | |||
| if (save_variance_ != nullptr) { | |||
| free(save_variance_); | |||
| save_variance_ = nullptr; | |||
| } | |||
| } | |||
| int FusedBatchnormCPUKernel::InitConstTensor() { | |||
| @@ -49,8 +57,11 @@ int FusedBatchnormCPUKernel::InitConstTensor() { | |||
| offset_ = malloc(offset->Size()); | |||
| mean_ = malloc(mean->Size()); | |||
| variance_ = malloc(variance->Size()); | |||
| save_mean_ = malloc(mean->Size()); | |||
| save_variance_ = malloc(variance->Size()); | |||
| if (scale_ == nullptr || offset_ == nullptr || mean_ == nullptr || variance_ == nullptr) { | |||
| if (scale_ == nullptr || offset_ == nullptr || mean_ == nullptr || variance_ == nullptr || save_mean_ == nullptr || | |||
| save_variance_ == nullptr) { | |||
| FreeMeanAndVariance(); | |||
| FreeScaleAndOffset(); | |||
| MS_LOG(ERROR) << "Memory allocation failed"; | |||
| @@ -60,6 +71,15 @@ int FusedBatchnormCPUKernel::InitConstTensor() { | |||
| memcpy(offset_, offset->MutableData(), offset->Size()); | |||
| memcpy(mean_, mean->MutableData(), mean->Size()); | |||
| memcpy(variance_, variance->MutableData(), variance->Size()); | |||
| memset(save_mean_, 0, mean->Size()); | |||
| memset(save_variance_, 0, variance->Size()); | |||
| if (out_tensors_.size() > 4) { | |||
| for (size_t i = 1; i < out_tensors_.size(); i++) { | |||
| auto *data = static_cast<float *>(out_tensors_[i]->MutableData()); | |||
| std::fill(data, data + out_tensors_[i]->ElementsNum(), 0.f); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -70,15 +90,23 @@ int FusedBatchnormCPUKernel::Run() { | |||
| return ret; | |||
| } | |||
| auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_); | |||
| if (is_train()) { | |||
| if (is_train() && in_tensors_.size() >= 5) { | |||
| float *in = static_cast<float *>(in_tensors_[0]->MutableData()); | |||
| float *run_mean = static_cast<float *>(out_tensors_[1]->MutableData()); | |||
| float *run_var = static_cast<float *>(out_tensors_[2]->MutableData()); | |||
| float *save_mean = static_cast<float *>(out_tensors_[3]->MutableData()); | |||
| float *save_inv_var = static_cast<float *>(out_tensors_[4]->MutableData()); | |||
| std::fill(run_mean, run_mean + param->channel_, 0.f); | |||
| std::fill(run_var, run_var + param->channel_, 0.f); | |||
| FusedBatchNormFp32MeanVar(in, 0.9, run_mean, run_var, param, save_mean, save_inv_var); | |||
| float *scale = static_cast<float *>(in_tensors_[1]->MutableData()); | |||
| float *bias = static_cast<float *>(in_tensors_[2]->MutableData()); | |||
| float *mean = static_cast<float *>(in_tensors_[3]->MutableData()); | |||
| float *var = static_cast<float *>(in_tensors_[4]->MutableData()); | |||
| std::fill(mean, mean + in_tensors_[3]->ElementsNum(), 0.f); | |||
| std::fill(var, var + in_tensors_[4]->ElementsNum(), 0.f); | |||
| FusedBatchNormFp32MeanVar(in, mean, var, param, static_cast<float *>(save_mean_), | |||
| static_cast<float *>(save_variance_)); | |||
| memcpy(out_tensors_[3]->MutableData(), save_mean_, out_tensors_[3]->Size()); | |||
| memcpy(out_tensors_[4]->MutableData(), save_variance_, out_tensors_[4]->Size()); | |||
| memcpy(mean_, mean, in_tensors_[3]->Size()); | |||
| memcpy(variance_, var, in_tensors_[4]->Size()); | |||
| memcpy(scale_, scale, in_tensors_[1]->Size()); | |||
| memcpy(offset_, bias, in_tensors_[2]->Size()); | |||
| trained_ = true; // trained at least once | |||
| } | |||
| ret = ParallelLaunch(this->context_->thread_pool_, BatchNormRun, this, op_parameter_->thread_num_); | |||
| if (ret != RET_OK) { | |||
| @@ -87,6 +115,24 @@ int FusedBatchnormCPUKernel::Run() { | |||
| return ret; | |||
| } | |||
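| Tensor indices used by the training branch above, collected for reference (role names inferred from usage in this patch, not from a header): | |||
| // in_tensors_:  [0] x  [1] scale  [2] offset  [3] mean (zeroed, used as batch-mean scratch)  [4] variance (batch-var scratch) | |||
| // out_tensors_: [0] y  [1]/[2] running mean/var  [3] save_mean (EMA)  [4] save_variance (EMA) | |||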
| void FusedBatchnormCPUKernel::eval() { | |||
| LiteKernel::eval(); | |||
| if (trained_) { | |||
| float *run_mean = static_cast<float *>(in_tensors_[3]->MutableData()); | |||
| float *run_var = static_cast<float *>(in_tensors_[4]->MutableData()); | |||
| float *scale = static_cast<float *>(in_tensors_[1]->MutableData()); | |||
| float *bias = static_cast<float *>(in_tensors_[2]->MutableData()); | |||
| // Copy to input tensors for Model export | |||
| memcpy(run_mean, save_mean_, in_tensors_[3]->Size()); | |||
| memcpy(run_var, save_variance_, in_tensors_[4]->Size()); | |||
| // Copy to local variables | |||
| memcpy(mean_, run_mean, in_tensors_[3]->Size()); | |||
| memcpy(variance_, run_var, in_tensors_[4]->Size()); | |||
| memcpy(scale_, scale, in_tensors_[1]->Size()); | |||
| memcpy(offset_, bias, in_tensors_[2]->Size()); | |||
| } | |||
| } | |||
| int FusedBatchnormCPUKernel::DoExecute(int task_id) { | |||
| auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_); | |||
| FusedBatchNormFp32(in_tensors_.at(0)->MutableData(), scale_, offset_, mean_, variance_, param, task_id, | |||
| @@ -29,6 +29,7 @@ class FusedBatchnormCPUKernel : public BatchnormCPUKernel { | |||
| : BatchnormCPUKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~FusedBatchnormCPUKernel() { FreeScaleAndOffset(); } | |||
| void eval() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int InitConstTensor() override; | |||
| @@ -38,6 +39,9 @@ class FusedBatchnormCPUKernel : public BatchnormCPUKernel { | |||
| void FreeScaleAndOffset(); | |||
| void *scale_ = nullptr; | |||
| void *offset_ = nullptr; | |||
| void *save_mean_ = nullptr; | |||
| void *save_variance_ = nullptr; | |||
| bool trained_ = false; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -32,7 +32,13 @@ using mindspore::schema::ActivationType_RELU6; | |||
| using mindspore::schema::PrimitiveType_ActivationGrad; | |||
| namespace mindspore::kernel { | |||
| int ActivationGradCPUKernel::Init() { return RET_OK; } | |||
| int ActivationGradCPUKernel::Init() { | |||
| if (2 != in_tensors_.size()) { | |||
| MS_LOG(ERROR) << "ActivationGrad should have 2 input tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ActivationGradCPUKernel::ReSize() { return RET_OK; } | |||
| @@ -42,22 +48,32 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { | |||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| int length = in_tensors_.at(0)->ElementsNum(); | |||
| int stride = UP_DIV(length, thread_count_); | |||
| int count = MSMIN(stride, length - stride * task_id); | |||
| auto error_code = RET_OK; | |||
| if (param_act_grad_->type_ == schema::ActivationType_RELU) { | |||
| error_code = ReluGrad(yt_addr, input_addr, length, output_addr); | |||
| error_code = | |||
| ReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else if (param_act_grad_->type_ == schema::ActivationType_RELU6) { | |||
| error_code = Relu6Grad(yt_addr, input_addr, length, output_addr); | |||
| error_code = | |||
| Relu6Grad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) { | |||
| error_code = LReluGrad(yt_addr, input_addr, length, output_addr, param_act_grad_->alpha_); | |||
| error_code = LReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, | |||
| output_addr + stride * task_id, param_act_grad_->alpha_); | |||
| } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) { | |||
| error_code = SigmoidGrad(yt_addr, input_addr, length, output_addr); | |||
| error_code = | |||
| SigmoidGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else if (param_act_grad_->type_ == schema::ActivationType_TANH) { | |||
| error_code = TanhGrad(yt_addr, input_addr, length, output_addr); | |||
| error_code = | |||
| TanhGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) { | |||
| error_code = HSwishGrad(yt_addr, input_addr, length, output_addr); | |||
| error_code = | |||
| HSwishGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) { | |||
| error_code = HSigmoidGrad(yt_addr, input_addr, length, output_addr); | |||
| error_code = | |||
| HSigmoidGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else { | |||
| MS_LOG(ERROR) << "Activation type error"; | |||
| return RET_ERROR; | |||
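| The slice arithmetic above in numbers (illustrative values): length = 10, thread_count_ = 4 gives stride = UP_DIV(10, 4) = 3, so task_id 0..3 start at offsets 0, 3, 6, 9 with counts 3, 3, 3, 1; MSMIN clamps the final partial slice: | |||
| int stride = (length + threads - 1) / threads;  // UP_DIV(10, 4) == 3 | |||
| int count = length - stride * task_id;          // remaining elements for this task | |||
| if (count > stride) count = stride;             // MSMIN -> 3, 3, 3, 1 | |||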
| @@ -81,13 +97,13 @@ int ActivationGradRun(void *cdata, int task_id) { | |||
| int ActivationGradCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare failed."; | |||
| MS_LOG(ERROR) << "ActivationGradCPUKernel Prepare failed."; | |||
| return ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRun, this, thread_count_); | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]"; | |||
| MS_LOG(ERROR) << "Activation Grad function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| @@ -107,7 +123,7 @@ kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vector<lite::T | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ << ", type: " | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| @@ -19,6 +19,7 @@ | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| @@ -31,13 +32,7 @@ namespace mindspore::kernel { | |||
| int ApplyMomentumCPUKernel::ReSize() { return RET_OK; } | |||
| int ApplyMomentumCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int ApplyMomentumCPUKernel::Execute(int task_id) { | |||
| auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||
| auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||
| float learning_rate = reinterpret_cast<float *>(in_tensors_[2]->MutableData())[0]; | |||
| @@ -45,9 +40,41 @@ int ApplyMomentumCPUKernel::Run() { | |||
| float moment = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0]; | |||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||
| for (size_t i = 0; i < elem_num; ++i) { | |||
| accumulate[i] = accumulate[i] * moment + gradient[i]; // * (1.0 - moment); | |||
| weight[i] -= accumulate[i] * learning_rate; | |||
| if (apply_momentum_param_->use_nesterov_) { | |||
| for (size_t i = 0; i < elem_num; ++i) { | |||
| accumulate[i] = accumulate[i] * moment + gradient[i]; | |||
| weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < elem_num; ++i) { | |||
| accumulate[i] = accumulate[i] * moment + gradient[i]; | |||
| weight[i] -= accumulate[i] * learning_rate; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ApplyMomentumRun(void *cdata, int task_id) { | |||
| auto applyMomentum_kernel = reinterpret_cast<ApplyMomentumCPUKernel *>(cdata); | |||
| auto error_code = applyMomentum_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "apply Momentum run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ApplyMomentumCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ApplyMomentumCPUKernel Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Apply Momentum function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -77,6 +104,7 @@ kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Te | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| @@ -19,6 +19,7 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/fp32_grad/optimizer.h" | |||
| namespace mindspore::kernel { | |||
| class ApplyMomentumCPUKernel : public LiteKernel { | |||
| @@ -26,11 +27,17 @@ class ApplyMomentumCPUKernel : public LiteKernel { | |||
| explicit ApplyMomentumCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), apply_momentum_param_(nullptr) { | |||
| apply_momentum_param_ = reinterpret_cast<ApplyMomentumParameter *>(parameter); | |||
| } | |||
| ~ApplyMomentumCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| ApplyMomentumParameter *apply_momentum_param_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -20,6 +20,7 @@ | |||
| #include "nnacl/fp32_grad/reduce_grad.h" | |||
| #include "nnacl/fp32_grad/arithmetic_grad.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -36,14 +37,13 @@ int ArithmeticGradCPUKernel::Init() { | |||
| MS_ASSERT(dx2 != nullptr); | |||
| if ((Type() == PrimitiveType_MulGrad) || (Type() == PrimitiveType_DivGrad)) { | |||
| // if (inShape0.size() < inShape1.size()) | |||
| if (dx1->ElementsNum() < dx2->ElementsNum()) { | |||
| if (Type() == PrimitiveType_MulGrad) | |||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul2L; | |||
| else if (Type() == PrimitiveType_DivGrad) | |||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv2L; | |||
| } else if (dx2->ElementsNum() < dx1->ElementsNum()) { // if (inShape0.size() > inShape1.size()) | |||
| } else if (dx2->ElementsNum() < dx1->ElementsNum()) { | |||
| if (Type() == PrimitiveType_MulGrad) | |||
| arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul1L; | |||
| else if (Type() == PrimitiveType_DivGrad) | |||
| @@ -157,7 +157,6 @@ void ArithmeticGradCPUKernel::ArithmeticGradDiv1L(float *dy, int dy_size, float | |||
| ReduceSumByAxes(tile_data2, arithmeticParameter_->in_shape0_, dx2, arithmeticParameter_->in_shape1_, | |||
| arithmeticParameter_->ndim_); | |||
| for (int i = 0; i < dx2_size; i++) dx2[i] = -dx2[i]; | |||
| // ReduceNegSumPrefix(tile_data2, dy_size, dx2, dx2_size); //then reduce into dx2 | |||
| // broadcasting x2 | |||
| BroadcastDiv(dy, x2_data, tile_data0, tile_data1, dx1, dy_size, arithmeticParameter_); // broadcast directly to dx1 | |||
| @@ -180,7 +179,7 @@ void ArithmeticGradCPUKernel::ArithmeticGradDiv2L(float *dy, int dy_size, float | |||
| int ArithmeticGradCPUKernel::ReSize() { return RET_OK; } | |||
| int ArithmeticGradCPUKernel::Run() { | |||
| int ArithmeticGradCPUKernel::Execute(int task_id) { | |||
| auto dy = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||
| auto dx1 = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | |||
| auto dx2 = reinterpret_cast<float *>(out_tensors_[1]->MutableData()); | |||
| @@ -192,6 +191,30 @@ int ArithmeticGradCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int ArithmeticGradRun(void *cdata, int task_id) { | |||
| auto Arithmetic_kernel = reinterpret_cast<ArithmeticGradCPUKernel *>(cdata); | |||
| auto error_code = Arithmetic_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "ArithmeticGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ArithmeticGradCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ArithmeticGradCPUKernel Prepare failed."; | |||
| return ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticGradRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Arithmetic Grad function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuArithmeticGradFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::InnerContext *ctx, | |||
| @@ -68,6 +68,7 @@ class ArithmeticGradCPUKernel : public LiteKernel { | |||
| int InferShape(); | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| void ArithmeticGradAdd(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size); | |||
| @@ -19,6 +19,7 @@ | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -43,14 +44,9 @@ int BiasGradCPUKernel::Init() { | |||
| return RET_OK; | |||
| } | |||
| int BiasGradCPUKernel::ReSize() { return 0; } | |||
| int BiasGradCPUKernel::ReSize() { return RET_OK; } | |||
| int BiasGradCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare failed."; | |||
| return RET_ERROR; | |||
| } | |||
| int BiasGradCPUKernel::Execute(int task_id) { | |||
| auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| @@ -69,6 +65,30 @@ int BiasGradCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int BiasGradRun(void *cdata, int task_id) { | |||
| auto bias_kernel = reinterpret_cast<BiasGradCPUKernel *>(cdata); | |||
| auto error_code = bias_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "bias error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int BiasGradCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "BiasGradCPUKernel Prepare failed."; | |||
| return RET_ERROR; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, BiasGradRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuBiasGradFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| @@ -35,6 +35,7 @@ class BiasGradCPUKernel : public LiteKernel { | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| ArithmeticParameter *bias_param; | |||
| @@ -21,6 +21,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "nnacl/fp32_grad/batch_norm.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -33,23 +34,13 @@ namespace mindspore::kernel { | |||
| int BNGradCPUKernel::Init() { | |||
| auto *input_x = in_tensors_.at(1); | |||
| int channels = input_x->shape().at(kNHWC_C); | |||
| workspace_size = 4 * channels; | |||
| workspace = new (std::nothrow) float[workspace_size]; | |||
| if (workspace == nullptr) { | |||
| MS_LOG(ERROR) << "new workspace fail!"; | |||
| return RET_ERROR; | |||
| } | |||
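| // four per-channel buffers: mean, inverse variance, and two dxhat reduction sums | |||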
| SetWorkspaceSize(4 * channels * sizeof(float)); | |||
| return RET_OK; | |||
| } | |||
| int BNGradCPUKernel::ReSize() { return RET_OK; } | |||
| int BNGradCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int BNGradCPUKernel::Execute(int task_id) { | |||
| auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_); | |||
| auto *input_yt = in_tensors_.at(0); | |||
| auto *input_x = in_tensors_.at(1); | |||
| @@ -61,7 +52,9 @@ int BNGradCPUKernel::Run() { | |||
| int channels = input_x->Channel(); | |||
| int spatial = input_x->Height() * input_x->Width(); | |||
| float eps = bn_param->epsilon_; | |||
| std::fill(workspace, workspace + workspace_size, 0.f); | |||
| float *workspace = static_cast<float *>(GetWorkspace()); | |||
| std::fill(workspace, workspace + GetWorkspaceSize() / sizeof(*workspace), 0.f); | |||
| float *mean = workspace; | |||
| float *invar = mean + channels; | |||
| float *dxhat_sum = invar + channels; | |||
| @@ -82,6 +75,33 @@ int BNGradCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int BNGradRun(void *cdata, int task_id) { | |||
| auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata); | |||
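| // only task 0 does the work: BN grad is effectively single-threaded under ParallelLaunch | |||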
| if (task_id == 0) { | |||
| auto error_code = bn_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int BNGradCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "BNGradCPUKernel Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuBNGradFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| @@ -27,18 +27,12 @@ class BNGradCPUKernel : public LiteKernel { | |||
| explicit BNGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr), workspace_size(0) {} | |||
| ~BNGradCPUKernel() override { | |||
| if (workspace) delete[] workspace; | |||
| } | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~BNGradCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| float *workspace; | |||
| int workspace_size; | |||
| int Execute(int task_id); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_ | |||
| @@ -18,6 +18,7 @@ | |||
| #include "nnacl/fp32_grad/pack_ext.h" | |||
| #include "nnacl/fp32_grad/gemm.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -25,6 +26,14 @@ using mindspore::lite::RET_OK; | |||
| namespace mindspore::kernel { | |||
| int ConvolutionTrainCPUKernel::Init() { | |||
| if (2 != in_tensors_.size()) { | |||
| MS_LOG(ERROR) << "Convolution should have two inputs"; | |||
| return RET_ERROR; | |||
| } | |||
| if (1 != out_tensors_.size()) { | |||
| MS_LOG(ERROR) << "Convolution should have one output"; | |||
| return RET_ERROR; | |||
| } | |||
| auto conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter_); | |||
| auto *input_x = in_tensors_.at(kInputIndex); | |||
| auto *input_weight = in_tensors_.at(kWeightIndex); | |||
| @@ -46,22 +55,13 @@ int ConvolutionTrainCPUKernel::Init() { | |||
| int ws_size = conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->kernel_h_ * conv_param_->kernel_w_ * | |||
| conv_param_->input_channel_ / conv_param_->group_; | |||
| workspace = new (std::nothrow) float[ws_size]; | |||
| if (workspace == nullptr) { | |||
| MS_LOG(ERROR) << "new workspace fail!"; | |||
| return RET_ERROR; | |||
| } | |||
| SetWorkspaceSize(ws_size * sizeof(float)); | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionTrainCPUKernel::ReSize() { return RET_OK; } | |||
| int ConvolutionTrainCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int ConvolutionTrainCPUKernel::Execute(int task_id) { | |||
| auto conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter_); | |||
| auto *input_x = in_tensors_.at(kInputIndex); | |||
| auto *input_w = in_tensors_.at(kWeightIndex); | |||
| @@ -86,6 +86,7 @@ int ConvolutionTrainCPUKernel::Run() { | |||
| int m = out_h * out_w; | |||
| int n = out_ch / groups; | |||
| int k = k_h * k_w * in_ch / groups; | |||
| float *workspace = static_cast<float *>(GetWorkspace()); | |||
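| // workspace is the im2row scratch buffer reused for every (batch, group) gemm below | |||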
| memset(y_addr, 0, out_y->Size()); | |||
| @@ -99,6 +100,31 @@ int ConvolutionTrainCPUKernel::Run() { | |||
| gemm(0, 1, m, n, k, 1, mat_a, k, mat_b, k, 1, mat_c, out_ch); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionTrainRun(void *cdata, int task_id) { | |||
| auto conv_kernel = reinterpret_cast<ConvolutionTrainCPUKernel *>(cdata); | |||
| auto error_code = conv_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvolutionTrainRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionTrainCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvolutionTrainCPUKernel Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionTrainRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "conv train function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -26,17 +26,13 @@ class ConvolutionTrainCPUKernel : public LiteKernel { | |||
| explicit ConvolutionTrainCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr) {} | |||
| ~ConvolutionTrainCPUKernel() override { | |||
| if (workspace) delete[] workspace; | |||
| } | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~ConvolutionTrainCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| float *workspace; | |||
| int Execute(int task_id); | |||
| }; | |||
| kernel::LiteKernel *CpuConvTrainFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| @@ -20,6 +20,7 @@ | |||
| #include "nnacl/fp32_grad/pack_ext.h" | |||
| #include "nnacl/fp32_grad/gemm.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -50,26 +51,16 @@ int ConvolutionGradFilterCPUKernel::Init() { | |||
| conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | |||
| conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | |||
| int ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * | |||
| conv_param->input_channel_ / conv_param->group_; | |||
| workspace = new (std::nothrow) float[ws_size]; | |||
| if (workspace == nullptr) { | |||
| MS_LOG(ERROR) << "new workspace fail!"; | |||
| return RET_ERROR; | |||
| } | |||
| size_t ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * | |||
| conv_param->input_channel_ / conv_param->group_; | |||
| SetWorkspaceSize(ws_size * sizeof(float)); | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionGradFilterCPUKernel::ReSize() { return RET_OK; } | |||
| int ConvolutionGradFilterCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int ConvolutionGradFilterCPUKernel::Execute(int task_id) { | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_); | |||
| auto *input_dy = in_tensors_.at(0); | |||
| auto *input_x = in_tensors_.at(1); | |||
| @@ -84,8 +75,8 @@ int ConvolutionGradFilterCPUKernel::Run() { | |||
| int in_ch = conv_param->input_channel_; | |||
| int in_h = conv_param->input_h_; | |||
| int in_w = conv_param->input_w_; | |||
| int k_h = conv_param->kernel_h_; // out_dw->shape()[1]; | |||
| int k_w = conv_param->kernel_w_; // out_dw->shape()[2]; | |||
| int k_h = conv_param->kernel_h_; | |||
| int k_w = conv_param->kernel_w_; | |||
| int batch = conv_param->output_batch_; | |||
| int out_ch = conv_param->output_channel_; | |||
| int groups = conv_param->group_; | |||
| @@ -96,6 +87,8 @@ int ConvolutionGradFilterCPUKernel::Run() { | |||
| int n = k_h * k_w * in_ch / groups; | |||
| int k = out_ch / groups; | |||
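| // per group: dW += dY^T x im2row(X), with the col matrix rebuilt in workspace each iteration | |||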
| float *workspace = reinterpret_cast<float *>(GetWorkspace()); | |||
| // zero out the dw output buffer before accumulation | |||
| memset(dw_addr, 0, out_dw->Size()); | |||
| @@ -104,15 +97,39 @@ int ConvolutionGradFilterCPUKernel::Run() { | |||
| float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups); | |||
| float *mat_b = workspace; | |||
| float *mat_c = dw_addr + j * nweights / groups; | |||
| float *im = x_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups); | |||
| float *im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups); | |||
| im2row_hwc(im, mat_b, conv_param); | |||
| im2row_hwc(im, mat_b, conv_param, false); | |||
| gemm(1, 1, k, n, m, 1, mat_a, out_ch, mat_b, m, 1, mat_c, n); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionGradFilterRun(void *cdata, int task_id) { | |||
| auto convfilter_kernel = reinterpret_cast<ConvolutionGradFilterCPUKernel *>(cdata); | |||
| auto error_code = convfilter_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvolutionGradFilterRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionGradFilterCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvolutionGradFilterCPUKernel Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuConvGradFilterFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::InnerContext *ctx, | |||
| @@ -26,17 +26,14 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel { | |||
| explicit ConvolutionGradFilterCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr) {} | |||
| ~ConvolutionGradFilterCPUKernel() override { | |||
| if (workspace) delete[] workspace; | |||
| } | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~ConvolutionGradFilterCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| float *workspace = nullptr; | |||
| int Execute(int task_id); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -20,6 +20,7 @@ | |||
| #include "nnacl/fp32_grad/pack_ext.h" | |||
| #include "nnacl/fp32_grad/gemm.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -50,26 +51,16 @@ int ConvolutionGradInputCPUKernel::Init() { | |||
| conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | |||
| conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | |||
| int ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * | |||
| conv_param->input_channel_ / conv_param->group_; | |||
| size_t ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * | |||
| conv_param->input_channel_ / conv_param->group_; | |||
| workspace = new (std::nothrow) float[ws_size]; | |||
| if (workspace == nullptr) { | |||
| MS_LOG(ERROR) << "new workspace fail!"; | |||
| return RET_ERROR; | |||
| } | |||
| SetWorkspaceSize(ws_size * sizeof(float)); | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionGradInputCPUKernel::ReSize() { return 0; } | |||
| int ConvolutionGradInputCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int ConvolutionGradInputCPUKernel::ReSize() { return RET_OK; } | |||
| int ConvolutionGradInputCPUKernel::Execute(int task_id) { | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_); | |||
| auto *input_dy = in_tensors_.at(0); | |||
| auto *input_w = in_tensors_.at(1); | |||
| @@ -95,6 +86,7 @@ int ConvolutionGradInputCPUKernel::Run() { | |||
| int m = out_h * out_w; | |||
| int n = k_w * k_h * in_ch / groups; | |||
| int k = out_ch / groups; | |||
| float *workspace = reinterpret_cast<float *>(GetWorkspace()); | |||
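| // workspace receives each per-sample col matrix before col2im_hwc scatters it into dx | |||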
| memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w); | |||
| @@ -107,6 +99,32 @@ int ConvolutionGradInputCPUKernel::Run() { | |||
| col2im_hwc(mat_c, dx_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups), conv_param); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionGradInputRun(void *cdata, int task_id) { | |||
| auto convinput_kernel = reinterpret_cast<ConvolutionGradInputCPUKernel *>(cdata); | |||
| auto error_code = convinput_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "conv input error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionGradInputCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvolutionGradInputCPUKernel Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -26,17 +26,13 @@ class ConvolutionGradInputCPUKernel : public LiteKernel { | |||
| explicit ConvolutionGradInputCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr) {} | |||
| ~ConvolutionGradInputCPUKernel() override { | |||
| if (workspace) delete[] workspace; | |||
| } | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~ConvolutionGradInputCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| float *workspace; | |||
| int Execute(int task_id); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,155 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32_grad/deconvolution_grad_filter.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "nnacl/pack.h" | |||
| #include "nnacl/fp32_grad/pack_ext.h" | |||
| #include "nnacl/fp32_grad/gemm.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_DeConv2DGradFilter; | |||
| namespace mindspore::kernel { | |||
| int DeConvolutionGradFilterCPUKernel::Init() { | |||
| // dy is in input 0 | |||
| // x is in input 1 | |||
| // dw is output 0 | |||
| auto *x_tensor = in_tensors_.at(1); | |||
| MS_ASSERT(x_tensor != nullptr); | |||
| auto *dy_tensor = in_tensors_.at(0); | |||
| MS_ASSERT(dy_tensor != nullptr); | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_); | |||
| conv_param->output_batch_ = dy_tensor->shape().at(kNHWC_N); | |||
| conv_param->input_batch_ = x_tensor->shape().at(kNHWC_N); | |||
| conv_param->input_h_ = x_tensor->shape().at(kNHWC_H); | |||
| conv_param->input_w_ = x_tensor->shape().at(kNHWC_W); | |||
| // assume OutCh|kh|kw|InCh | |||
| conv_param->input_channel_ = x_tensor->shape().at(kNHWC_C); | |||
| conv_param->output_channel_ = dy_tensor->shape().at(kNHWC_C); | |||
| conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | |||
| conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | |||
| int ws_size = conv_param->input_h_ * conv_param->input_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * | |||
| conv_param->output_channel_ / conv_param->group_; | |||
| SetWorkspaceSize(ws_size * sizeof(float)); | |||
| return RET_OK; | |||
| } | |||
| int DeConvolutionGradFilterCPUKernel::ReSize() { return RET_OK; } | |||
| int DeConvolutionGradFilterCPUKernel::Execute(int task_id) { | |||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_); | |||
| auto *input_dy = in_tensors_.at(0); | |||
| auto *input_x = in_tensors_.at(1); | |||
| auto *out_dw = out_tensors_.at(0); | |||
| auto x_addr = reinterpret_cast<float *>(input_x->MutableData()); | |||
| auto dy_addr = reinterpret_cast<float *>(input_dy->MutableData()); | |||
| auto dw_addr = reinterpret_cast<float *>(out_dw->MutableData()); | |||
| int i, j; | |||
| int in_ch = conv_param->input_channel_; | |||
| int in_h = conv_param->input_h_; | |||
| int in_w = conv_param->input_w_; | |||
| int k_h = conv_param->kernel_h_; | |||
| int k_w = conv_param->kernel_w_; | |||
| int batch = conv_param->output_batch_; | |||
| int out_ch = conv_param->output_channel_; | |||
| int groups = conv_param->group_; | |||
| int out_h = conv_param->output_h_; | |||
| int out_w = conv_param->output_w_; | |||
| int m = in_ch / groups; | |||
| int n = k_h * k_w * out_ch / groups; | |||
| int k = in_h * in_w; | |||
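| // per group: dW accumulates im2row(dY) x X, mirroring the conv filter grad with operand roles swapped | |||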
| float *workspace = reinterpret_cast<float *>(GetWorkspace()); | |||
| // zero out the dw output buffer before accumulation | |||
| memset(dw_addr, 0, out_dw->Size()); | |||
| for (i = 0; i < batch; ++i) { | |||
| for (j = 0; j < groups; ++j) { | |||
| float *mat_a = x_addr + (i * (in_ch * in_h * in_w) + j * (in_ch / groups)); | |||
| float *mat_b = workspace; | |||
| float *mat_c = dw_addr + j * m; | |||
| float *im = dy_addr + (i * (out_h * out_w * out_ch) + j * (out_ch / groups)); | |||
| im2row_hwc(im, mat_b, conv_param, true); | |||
| gemm(0, 0, n, m, k, 1, mat_b, k, mat_a, in_ch, 1, mat_c, in_ch); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int DeConvolutionGradFilterRun(void *cdata, int task_id) { | |||
| auto convfilter_kernel = reinterpret_cast<DeConvolutionGradFilterCPUKernel *>(cdata); | |||
| auto error_code = convfilter_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "DeConvolutionGradFilterRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int DeConvolutionGradFilterCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, DeConvolutionGradFilterRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuDeConvGradFilterFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::InnerContext *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2DGradFilter); | |||
| auto *kernel = new (std::nothrow) DeConvolutionGradFilterCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new kernel fail!"; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DeConv2DGradFilter, CpuDeConvGradFilterFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DECONVOLUTION_GRAD_FILTER_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DECONVOLUTION_GRAD_FILTER_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| namespace mindspore::kernel { | |||
| class DeConvolutionGradFilterCPUKernel : public LiteKernel { | |||
| public: | |||
| explicit DeConvolutionGradFilterCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~DeConvolutionGradFilterCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DECONVOLUTION_GRAD_FILTER_H_ | |||
| @@ -36,6 +36,7 @@ class MakeTupleCPUKernel : public LiteKernel { | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| OpParameter *param; | |||
| @@ -20,6 +20,8 @@ | |||
| #include "nnacl/fp32/pooling.h" | |||
| #include "nnacl/fp32_grad/pooling_grad.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -60,12 +62,7 @@ int PoolingGradCPUKernel::Init() { | |||
| int PoolingGradCPUKernel::ReSize() { return RET_OK; } | |||
| int PoolingGradCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int PoolingGradCPUKernel::Execute(int task_id) { | |||
| PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(op_parameter_); | |||
| auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| @@ -73,9 +70,41 @@ int PoolingGradCPUKernel::Run() { | |||
| if (pool_param->pool_mode_ == PoolMode_MaxPool) { | |||
| auto dx_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||
| auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||
| MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param); | |||
| MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param, task_id); | |||
| } else { | |||
| AvgPoolingGrad(input_ptr, output_ptr, pool_param); | |||
| AvgPoolingGrad(input_ptr, output_ptr, pool_param, task_id); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int PoolingGradImpl(void *cdata, int task_id) { | |||
| auto pooling = reinterpret_cast<PoolingGradCPUKernel *>(cdata); | |||
| auto error_code = pooling->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Pooling Run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int PoolingGradCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "PoolingGradCPUKernel Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| // clear output buffer before parallel run | |||
| PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(op_parameter_); | |||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
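| // dx has the forward input's shape: input_h * input_w * channels * batch elements to clear | |||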
| int size = | |||
| pooling_param->input_w_ * pooling_param->input_h_ * pooling_param->input_channel_ * pooling_param->output_batch_; | |||
| for (int i = 0; i < size; i++) output_ptr[i] = 0.0; | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -37,6 +37,9 @@ class PoolingGradCPUKernel : public LiteKernel { | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -19,6 +19,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "nnacl/fp32/arithmetic.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -26,11 +27,21 @@ using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_PowerGrad; | |||
| namespace mindspore::kernel { | |||
| int PowerGradCPUKernel::Init() { return RET_OK; } | |||
| int PowerGradCPUKernel::Init() { | |||
| if (2 != in_tensors_.size()) { | |||
| MS_LOG(ERROR) << "Power Grad Filter should have 2 inputs"; | |||
| return RET_ERROR; | |||
| } | |||
| if (1 != out_tensors_.size()) { | |||
| MS_LOG(ERROR) << "Power Grad Filter should have one output"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int PowerGradCPUKernel::ReSize() { return RET_OK; } | |||
| int PowerGradCPUKernel::Run() { | |||
| int PowerGradCPUKernel::Execute(int task_id) { | |||
| auto dy_addr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| auto x_addr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||
| auto dx_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| @@ -47,6 +58,30 @@ int PowerGradCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int PowerGradRun(void *cdata, int task_id) { | |||
| auto power_kernel = reinterpret_cast<PowerGradCPUKernel *>(cdata); | |||
| auto error_code = power_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "power grad error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int PowerGradCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "PowerGradCPUKernel Prepare failed."; | |||
| return RET_ERROR; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "power grad function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuPowerGradFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| @@ -38,6 +38,7 @@ class PowerGradCPUKernel : public LiteKernel { | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| float power_; | |||
| @@ -0,0 +1,121 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32_grad/sgd.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Sgd; | |||
| namespace mindspore::kernel { | |||
| int SgdCPUKernel::ReSize() { return RET_OK; } | |||
| int SgdCPUKernel::Execute(int task_id) { | |||
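| // input tensor order: 0 weight, 1 gradient, 2 learning rate, 3 accumulation, 4 moment | |||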
| auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||
| auto accumulate = reinterpret_cast<float *>(in_tensors_[3]->MutableData()); | |||
| float learning_rate = reinterpret_cast<float *>(in_tensors_[2]->MutableData())[0]; | |||
| auto gradient = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||
| float moment = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0]; | |||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||
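| // nesterov: v = m*v + g, w -= (m*v + g)*lr; vanilla: v = m*v + (1 - dampening)*g, w -= v*lr | |||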
| if (sgd_param_->use_nesterov_) { | |||
| for (size_t i = 0; i < elem_num; ++i) { | |||
| accumulate[i] = accumulate[i] * moment + gradient[i]; | |||
| weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < elem_num; ++i) { | |||
| accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_); | |||
| weight[i] -= accumulate[i] * learning_rate; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SgdRun(void *cdata, int task_id) { | |||
| auto Sgd_kernel = reinterpret_cast<SgdCPUKernel *>(cdata); | |||
| auto error_code = Sgd_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "SGD run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SgdCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SgdCPUKernel Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "SGD function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SgdCPUKernel::Init() { | |||
| // Zero-initialize the accumulation tensor; test-only workaround for uninitialized input data | |||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||
| auto accumulate = reinterpret_cast<float *>(in_tensors_[3]->MutableData()); | |||
| for (size_t i = 0; i < elem_num; i++) accumulate[i] = 0.0; | |||
| if (sgd_param_->dampening_ < 0.0f) { | |||
| MS_LOG(ERROR) << "dampening should be at least 0.0"; | |||
| return RET_ERROR; | |||
| } | |||
| if (sgd_param_->use_nesterov_ && sgd_param_->dampening_ > 0.0f) { | |||
| MS_LOG(ERROR) << "If use nesterov, dampening must equal to 0.0"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuSgdFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Sgd); | |||
| auto *kernel = new (std::nothrow) SgdCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new SgdCPUKernel fail!"; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sgd, CpuSgdFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SGD_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SGD_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/fp32_grad/optimizer.h" | |||
| namespace mindspore::kernel { | |||
| class SgdCPUKernel : public LiteKernel { | |||
| public: | |||
| explicit SgdCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), sgd_param_(nullptr) { | |||
| sgd_param_ = reinterpret_cast<SgdParameter *>(parameter); | |||
| } | |||
| ~SgdCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| SgdParameter *sgd_param_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SGD_H_ | |||
| @@ -20,6 +20,7 @@ | |||
| #include "nnacl/fp32/softmax.h" | |||
| #include "src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -56,13 +57,8 @@ void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *lab | |||
| } | |||
| output2[0] = total_loss / param_->batch_size_; | |||
| } | |||
| int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare failed."; | |||
| return ret; | |||
| } | |||
| int SoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) { | |||
| auto ins = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| auto labels = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||
| float *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| @@ -75,6 +71,8 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||
| MS_ASSERT(out != nullptr); | |||
| MS_ASSERT(labels != nullptr); | |||
| MS_ASSERT(ins != nullptr); | |||
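| // workspace was sized in Init as (data_size + batch) floats: losses first, then per-sample sums | |||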
| float *losses_ = static_cast<float *>(GetWorkspace()); | |||
| float *sum_data_ = losses_ + data_size; | |||
| std::fill(losses_, losses_ + data_size, 0); | |||
| std::fill(sum_data_, sum_data_ + sm_params_.input_shape_[0], 0); | |||
| Softmax(ins, losses_, sum_data_, &sm_params_); | |||
| @@ -82,6 +80,31 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int SoftmaxCrossEntropyWithLogitsRun(void *cdata, int task_id) { | |||
| auto softmax_kernel = reinterpret_cast<SoftmaxCrossEntropyWithLogitsCPUKernel *>(cdata); | |||
| auto error_code = softmax_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "SoftmaxCrossEntropy error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SoftmaxCrossEntropyWithLogitsCPUKernel Prepare failed."; | |||
| return ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SoftmaxCrossEntropyWithLogitsRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "SoftmaxCrossEntropy function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() { | |||
| auto dims = in_tensors_[0]->shape(); | |||
| param_->n_dim_ = 2; | |||
| @@ -99,18 +122,7 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() { | |||
| } | |||
| size_t data_size = in_tensors_.at(0)->ElementsNum(); | |||
| losses_ = new (std::nothrow) float[data_size]; | |||
| if (losses_ == nullptr) { | |||
| MS_LOG(ERROR) << "failed to malloc losses!"; | |||
| return RET_ERROR; | |||
| } | |||
| sum_data_ = new (std::nothrow) float[dims[0]]; | |||
| if (sum_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "failed to malloc sum_data_!"; | |||
| return RET_ERROR; | |||
| } | |||
| SetWorkspaceSize((data_size + dims[0]) * sizeof(float)); | |||
| sm_params_.n_dim_ = 2; | |||
| sm_params_.element_size_ = data_size; | |||
| sm_params_.axis_ = 1; | |||
| @@ -138,5 +150,4 @@ kernel::LiteKernel *CpuSoftmaxCrossEntropyFp32KernelCreator(const std::vector<li | |||
| } | |||
| return kernel; | |||
| } | |||
| // REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropy, CpuSoftmaxCrossEntropyFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -31,27 +31,21 @@ class SoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LossKernel(parameter, inputs, outputs, ctx, primitive), losses_(nullptr), sum_data_(nullptr) { | |||
| : LossKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| param_ = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter); | |||
| } | |||
| ~SoftmaxCrossEntropyWithLogitsCPUKernel() override { | |||
| if (losses_) delete[] losses_; | |||
| if (sum_data_) delete[] sum_data_; | |||
| } | |||
| ~SoftmaxCrossEntropyWithLogitsCPUKernel() override {} | |||
| void ForwardPostExecute(const float *labels, const float *logits, float *output1, float *output2) const; | |||
| // void ForwardPostExecute(const int *labels, const float *losses, float *output) const; | |||
| // void GradPostExecute(const int *labels, const float *losses, float* grads, float *output) const; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| SoftmaxCrossEntropyParameter *param_; | |||
| SoftmaxParameter sm_params_; | |||
| float *losses_ = nullptr; | |||
| float *sum_data_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -20,6 +20,7 @@ | |||
| #include "nnacl/fp32_grad/softmax_grad.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -46,33 +47,49 @@ int SoftmaxGradCPUKernel::Init() { | |||
| axis = param->axis_ = (in_dims - 1); | |||
| } | |||
| int inner_size = 1; | |||
| inner_size_ = 1; | |||
| for (size_t i = axis + 1; i < in_dims; i++) { | |||
| inner_size *= in_shape[i]; | |||
| inner_size_ *= in_shape[i]; | |||
| } | |||
| sum_data_ = new (std::nothrow) float[inner_size]; | |||
| if (sum_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "failed to malloc sum_data_!"; | |||
| return RET_ERROR; | |||
| } | |||
| sum_mul_ = new (std::nothrow) float[inner_size * in_shape[axis]]; | |||
| if (sum_mul_ == nullptr) { | |||
| MS_LOG(ERROR) << "failed to malloc sum_mul_!"; | |||
| return RET_ERROR; | |||
| } | |||
| SetWorkspaceSize(inner_size_ * (1 + in_shape[axis]) * sizeof(float)); | |||
| return RET_OK; | |||
| } | |||
| int SoftmaxGradCPUKernel::ReSize() { return RET_OK; } | |||
| int SoftmaxGradCPUKernel::Run() { | |||
| int SoftmaxGradCPUKernel::Execute(int task_id) { | |||
| auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData()); | |||
| auto yt_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData()); | |||
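| // workspace from Init: inner_size_ floats for sum_data, inner_size_ * dim(axis) floats for sum_mul | |||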
| float *sum_data_ = static_cast<float *>(GetWorkspace()); | |||
| float *sum_mul_ = sum_data_ + inner_size_; | |||
| SoftmaxGrad(input_ptr, yt_ptr, output_ptr, sum_data_, sum_mul_, reinterpret_cast<SoftmaxParameter *>(op_parameter_)); | |||
| return RET_OK; | |||
| } | |||
| int SoftmaxGradRun(void *cdata, int task_id) { | |||
| auto softmax_kernel = reinterpret_cast<SoftmaxGradCPUKernel *>(cdata); | |||
| auto error_code = softmax_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "softmax_kernel SoftmaxGradRun task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SoftmaxGradCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SoftmaxGradCPUKernel Prepare failed."; | |||
| return ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SoftmaxGradRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "SoftmaxGradRun function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -27,21 +27,18 @@ class SoftmaxGradCPUKernel : public LiteKernel { | |||
| explicit SoftmaxGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), sum_data_(nullptr), sum_mul_(nullptr) { | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| param = reinterpret_cast<SoftmaxParameter *>(parameter); | |||
| } | |||
| ~SoftmaxGradCPUKernel() override { | |||
| if (sum_data_) delete[] sum_data_; | |||
| if (sum_mul_) delete[] sum_mul_; | |||
| } | |||
| ~SoftmaxGradCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| SoftmaxParameter *param; | |||
| float *sum_data_ = nullptr; | |||
| float *sum_mul_ = nullptr; | |||
| size_t inner_size_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -20,6 +20,7 @@ | |||
| #include "nnacl/fp32/softmax.h" | |||
| #include "src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -80,13 +81,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *lab | |||
| return RET_OK; | |||
| } | |||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare failed."; | |||
| return ret; | |||
| } | |||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) { | |||
| auto ins = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| auto labels = reinterpret_cast<int *>(in_tensors_.at(1)->MutableData()); | |||
| float *out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| @@ -98,8 +93,11 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||
| MS_ASSERT(out != nullptr); | |||
| MS_ASSERT(labels != nullptr); | |||
| MS_ASSERT(ins != nullptr); | |||
| std::fill(losses_, losses_ + data_size, 0); | |||
| std::fill(sum_data_, sum_data_ + sm_params_.input_shape_[0], 0); | |||
| float *losses_ = static_cast<float *>(GetWorkspace()); | |||
| float *sum_data_ = losses_ + data_size; | |||
| std::fill(losses_, losses_ + data_size, 0.f); | |||
| std::fill(sum_data_, sum_data_ + sm_params_.input_shape_[0], 0.f); | |||
| Softmax(ins, losses_, sum_data_, &sm_params_); | |||
| if (is_train()) { | |||
| GradPostExecute(labels, losses_, grads, out); | |||
| @@ -109,6 +107,30 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int SparseSoftmaxCrossEntropyRun(void *cdata, int task_id) { | |||
| auto sparse_kernel = reinterpret_cast<SparseSoftmaxCrossEntropyWithLogitsCPUKernel *>(cdata); | |||
| auto error_code = sparse_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyWithLogitsCPUKernel Prepare failed."; | |||
| return ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, SparseSoftmaxCrossEntropyRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "SparseSoftmaxCrossEntropy function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() { | |||
| auto dims = in_tensors_[0]->shape(); | |||
| param->n_dim_ = 2; | |||
| @@ -125,18 +147,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() { | |||
| return RET_ERROR; | |||
| } | |||
| size_t data_size = in_tensors_.at(0)->ElementsNum(); | |||
| losses_ = new (std::nothrow) float[data_size]; | |||
| if (losses_ == nullptr) { | |||
| MS_LOG(ERROR) << "failed to malloc losses!"; | |||
| return RET_ERROR; | |||
| } | |||
| sum_data_ = new (std::nothrow) float[dims[0]]; | |||
| if (sum_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "failed to malloc sum_data_!"; | |||
| return RET_ERROR; | |||
| } | |||
| SetWorkspaceSize((data_size + dims[0]) * sizeof(float)); | |||
| sm_params_.n_dim_ = 2; | |||
| sm_params_.element_size_ = data_size; | |||
| sm_params_.axis_ = 1; | |||
| @@ -32,13 +32,10 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LossKernel(parameter, inputs, outputs, ctx, primitive), losses_(nullptr), sum_data_(nullptr) { | |||
| : LossKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter); | |||
| } | |||
| ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override { | |||
| if (losses_) delete[] losses_; | |||
| if (sum_data_) delete[] sum_data_; | |||
| } | |||
| ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override {} | |||
| int ForwardPostExecute(const int *labels, const float *losses, float *output) const; | |||
| int GradPostExecute(const int *labels, const float *losses, float *grads, float *output) const; | |||
| @@ -46,12 +43,11 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| SoftmaxCrossEntropyParameter *param; | |||
| SoftmaxParameter sm_params_; | |||
| float *losses_ = nullptr; | |||
| float *sum_data_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -19,6 +19,7 @@ | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -28,16 +29,21 @@ using mindspore::schema::PrimitiveType_TupleGetItem; | |||
| namespace mindspore::kernel { | |||
| int TupleGetItemCPUKernel::Init() { return RET_OK; } | |||
| int TupleGetItemCPUKernel::ReSize() { return 0; } | |||
| int TupleGetItemCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare failed."; | |||
| int TupleGetItemCPUKernel::Init() { | |||
| if (1 != in_tensors_.size()) { | |||
| MS_LOG(ERROR) << "Tuple Grad Filter should have one input"; | |||
| return RET_ERROR; | |||
| } | |||
| if (1 != out_tensors_.size()) { | |||
| MS_LOG(ERROR) << "Tuple Grad Filter should have one output"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int TupleGetItemCPUKernel::ReSize() { return RET_OK; } | |||
| int TupleGetItemCPUKernel::Execute(int task_id) { | |||
| auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| @@ -46,6 +52,30 @@ int TupleGetItemCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int TupleRun(void *cdata, int task_id) { | |||
| auto tuple_kernel = reinterpret_cast<TupleGetItemCPUKernel *>(cdata); | |||
| auto error_code = tuple_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "tuple grad error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int TupleGetItemCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "TupleGetItemCPUKernel Prepare failed."; | |||
| return RET_ERROR; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "tuple function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuTupleGetItemFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::InnerContext *ctx, | |||
| @@ -35,6 +35,7 @@ class TupleGetItemCPUKernel : public LiteKernel { | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| OpParameter *param; | |||
| @@ -29,6 +29,11 @@ | |||
| #include "nnacl/power_parameter.h" | |||
| #include "src/ops/bias_grad.h" | |||
| #include "nnacl/arithmetic_common.h" | |||
| #include "nnacl/fp32_grad/optimizer.h" | |||
| #include "src/ops/apply_momentum.h" | |||
| #include "src/ops/sgd.h" | |||
| #include "src/ops/bn_grad.h" | |||
| #include "nnacl/fp32_grad/batch_norm.h" | |||
| namespace mindspore::kernel { | |||
| @@ -48,6 +53,49 @@ OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primiti | |||
| return param; | |||
| } | |||
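| // the populate funcs below copy primitive attributes into malloc'ed parameter structs for the kernels | |||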
| OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||
| return nullptr; | |||
| } | |||
| ApplyMomentumParameter *p = reinterpret_cast<ApplyMomentumParameter *>(malloc(sizeof(ApplyMomentumParameter))); | |||
| if (p == nullptr) { | |||
| MS_LOG(ERROR) << "new ApplyMomentumParameter failed."; | |||
| return nullptr; | |||
| } | |||
| p->op_parameter_.type_ = primitive->Type(); | |||
| auto apply_momentum_primitive = | |||
| reinterpret_cast<mindspore::lite::ApplyMomentum *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| p->grad_scale_ = apply_momentum_primitive->GetGradientScale(); | |||
| p->use_locking_ = apply_momentum_primitive->GetUseLocking(); | |||
| p->use_nesterov_ = apply_momentum_primitive->GetUseNesterov(); | |||
| return reinterpret_cast<OpParameter *>(p); | |||
| } | |||
| OpParameter *PopulateSgdParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||
| return nullptr; | |||
| } | |||
| SgdParameter *p = reinterpret_cast<SgdParameter *>(malloc(sizeof(SgdParameter))); | |||
| if (p == nullptr) { | |||
| MS_LOG(ERROR) << "new SgdParameter failed."; | |||
| return nullptr; | |||
| } | |||
| p->op_parameter_.type_ = primitive->Type(); | |||
| auto sgd_primitive = reinterpret_cast<mindspore::lite::Sgd *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| p->weight_decay_ = sgd_primitive->GetWeightDecay(); | |||
| p->dampening_ = sgd_primitive->GetDampening(); | |||
| p->use_nesterov_ = sgd_primitive->GetUseNesterov(); | |||
| return reinterpret_cast<OpParameter *>(p); | |||
| } | |||
| OpParameter *PopulateSoftmaxCrossEntropyParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||
| @@ -250,9 +298,27 @@ OpParameter *PopulateBiasGradParameter(const mindspore::lite::PrimitiveC *primit | |||
| return reinterpret_cast<OpParameter *>(arithmetic_param); | |||
| } | |||
| OpParameter *PopulateBNGradParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||
| return nullptr; | |||
| } | |||
| BNGradParameter *bnGrad_param = reinterpret_cast<BNGradParameter *>(malloc(sizeof(BNGradParameter))); | |||
| if (bnGrad_param == nullptr) { | |||
| MS_LOG(ERROR) << "malloc BNGradParameter failed."; | |||
| return nullptr; | |||
| } | |||
| bnGrad_param->op_parameter_.type_ = primitive->Type(); | |||
| auto bngrad = reinterpret_cast<mindspore::lite::BNGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| bnGrad_param->epsilon_ = bngrad->GetEps(); | |||
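| // momentum is not read from the primitive here; a fixed default of 0.1 is used | |||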
| bnGrad_param->momentum_ = 0.1; | |||
| return reinterpret_cast<OpParameter *>(bnGrad_param); | |||
| } | |||
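| // Registers populate functions for the training-only ops (optimizers and | |||
| // gradient kernels) so graph compilation can build their OpParameter structs. | |||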
| void PopulateTrainParameters() { | |||
| auto ppr = PopulateParameterRegistry::GetInstance(); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_ApplyMomentum, DefaultPopulateParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_BiasGrad, PopulateBiasGradParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_SoftmaxCrossEntropy, PopulateSoftmaxCrossEntropyParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter); | |||
| @@ -263,6 +329,8 @@ void PopulateTrainParameters() { | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradInput, PopulateConvolutionGradInputParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_PoolingGrad, PopulatePoolingGradParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Sgd, PopulateSgdParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, PopulateBNGradParameter); | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -14,12 +14,12 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/train_session.h" | |||
| #include "src/train/train_session.h" | |||
| #include <algorithm> | |||
| #include "src/common/log_adapter.h" | |||
| #include "include/context.h" | |||
| #include "include/train_model.h" | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "include/errorcode.h" | |||
| #include "include/train_model.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/tensor.h" | |||
| #include "src/train/loss_kernel.h" | |||
| @@ -29,7 +29,8 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/fp32_grad/convolution.h" | |||
| namespace mindspore::session { | |||
| namespace mindspore { | |||
| namespace lite { | |||
| static size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter) { | |||
| for (size_t i = 0; i < where.size(); i++) { | |||
| @@ -42,45 +43,72 @@ static size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite: | |||
| TrainSession::TrainSession() { kernel::PopulateTrainParameters(); } | |||
| void TrainSession::ReplaceOps() { | |||
| mindspore::lite::KernelRegistrar tmp(mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, | |||
| mindspore::schema::PrimitiveType_Conv2D, | |||
| mindspore::kernel::CpuConvTrainFp32KernelCreator); | |||
| std::vector<CreatorOp> TrainSession::ReplaceOps() { | |||
| const std::vector<CreatorOp> replace = { | |||
| {{mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Conv2D}, | |||
| mindspore::kernel::CpuConvTrainFp32KernelCreator}, | |||
| {{mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_DepthwiseConv2D}, | |||
| mindspore::kernel::CpuConvTrainFp32KernelCreator}}; | |||
| mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance(); | |||
| std::vector<CreatorOp> results; | |||
| for (auto v : replace) { | |||
| const CreatorOp cl = std::make_tuple(std::get<0>(v), reg->GetCreator(std::get<0>(v))); | |||
| results.push_back(cl); | |||
| reg->RegKernel(std::get<0>(v), std::get<1>(v)); | |||
| } | |||
| return results; | |||
| } | |||
| mindspore::lite::KernelRegistrar tmp0(mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, | |||
| mindspore::schema::PrimitiveType_DepthwiseConv2D, | |||
| mindspore::kernel::CpuConvTrainFp32KernelCreator); | |||
| void TrainSession::RestoreOps(const std::vector<CreatorOp> &restore) { | |||
| mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance(); | |||
| for (auto v : restore) { | |||
| reg->RegKernel(std::get<0>(v), std::get<1>(v)); | |||
| } | |||
| } | |||
| int TrainSession::CompileGraph(lite::Model *model) { | |||
| model_ = reinterpret_cast<lite::TrainModel *>(model); | |||
| if (model_ == nullptr) { | |||
| MS_LOG(ERROR) << "TrainSession can only compile TrainModels"; | |||
| return lite::RET_ERROR; | |||
| void TrainSession::AllocWorkSpace() { | |||
| size_t workspace_size = 0; | |||
| for (auto k : kernels_) { | |||
| if (workspace_size < k->GetWorkspaceSize()) { | |||
| workspace_size = k->GetWorkspaceSize(); | |||
| } | |||
| } | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(workspace_size); | |||
| } | |||
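| // All kernels share one scratch workspace sized to the largest per-kernel | |||
| // requirement; this assumes kernels execute one at a time. | |||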
| int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; } | |||
| ReplaceOps(); | |||
| auto ret = LiteSession::CompileGraph(model); | |||
| int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) { | |||
| model_ = model; | |||
| auto restore = ReplaceOps(); | |||
| auto ret = lite::LiteSession::CompileGraph(model); | |||
| orig_output_map_ = output_node_map_; | |||
| orig_output_tensor_map_ = output_tensor_map_; | |||
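| // touch the inputs so their buffers are allocated up front (MutableData | |||
| // allocates the backing memory on first use) | |||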
| for (auto inTensor : inputs_) inTensor->MutableData(); | |||
| RestoreOps(restore); | |||
| AllocWorkSpace(); | |||
| return ret; | |||
| } | |||
| TrainSession::~TrainSession() { delete model_; } | |||
| TrainSession::~TrainSession() { | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete model_; | |||
| } | |||
| void *TrainSession::ExportToBuf(char *buf, size_t *len) const { return model_->ExportBuf(buf, len); } | |||
| int TrainSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) { | |||
| this->outputs_.clear(); | |||
| for (auto ms_tensors : output_node_map_) | |||
| for (auto ms_tensor : ms_tensors.second) this->outputs_.push_back((reinterpret_cast<lite::Tensor *>(ms_tensor))); | |||
| if (train_mode_) return LiteSession::RunGraph(before, after); | |||
| for (auto ms_tensor : ms_tensors.second) this->outputs_.push_back((static_cast<lite::Tensor *>(ms_tensor))); | |||
| if (train_mode_) return lite::LiteSession::RunGraph(before, after); | |||
| // this object is expected to run only the inference part of the graph | |||
| // prepare a list of kernels up to the loss function -- temporary solution | |||
| std::vector<kernel::LiteKernel *> inference_kernels; | |||
| for (auto kernel : this->kernels_) { | |||
| if (reinterpret_cast<const kernel::LossKernel *>(kernel) != nullptr) break; | |||
| if (IsLossKernel(kernel)) break; | |||
| inference_kernels.push_back(kernel); | |||
| } | |||
| @@ -106,9 +134,10 @@ void TrainSession::Train() { | |||
| output_tensor_map_.clear(); | |||
| train_mode_ = true; | |||
| for (auto kernel : this->kernels_) { | |||
| if (reinterpret_cast<const kernel::LossKernel *>(kernel) != nullptr) { | |||
| if (IsLossKernel(kernel)) { | |||
| auto *ms_tensor = kernel->out_tensors().at(0); | |||
| if (ms_tensor != nullptr) { | |||
| ms_tensor->MutableData(); | |||
| output_node_map_[kernel->name()].emplace_back(ms_tensor); | |||
| auto index = TSFindTensor(tensors_, ms_tensor); | |||
| if (index != tensors_.size()) { | |||
| @@ -124,26 +153,43 @@ void TrainSession::Eval() { | |||
| MS_ASSERT(nullptr != kernel); | |||
| kernel->eval(); | |||
| } | |||
| kernel::LiteKernel *last_kernel = nullptr; | |||
| output_node_map_ = orig_output_map_; | |||
| output_tensor_map_ = orig_output_tensor_map_; | |||
| train_mode_ = false; | |||
| for (auto kernel : this->kernels_) { | |||
| if ((reinterpret_cast<const kernel::LossKernel *>(kernel) != nullptr) && (last_kernel != nullptr)) { | |||
| if (output_node_map_.find(last_kernel->name()) == output_node_map_.end()) { | |||
| auto *ms_tensor = last_kernel->out_tensors().at(0); | |||
| if (ms_tensor != nullptr) { | |||
| output_node_map_[last_kernel->name()].emplace_back(ms_tensor); | |||
| auto index = TSFindTensor(tensors_, ms_tensor); | |||
| if (index != tensors_.size()) { | |||
| output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); | |||
| if (IsLossKernel(kernel)) { | |||
| for (auto in_kernel : kernel->in_kernels()) { | |||
| if (output_node_map_.find(in_kernel->name()) == output_node_map_.end()) { | |||
| auto *ms_tensor = in_kernel->out_tensors().at(0); | |||
| if (ms_tensor != nullptr) { | |||
| output_node_map_[in_kernel->name()].emplace_back(ms_tensor); | |||
| auto index = TSFindTensor(tensors_, ms_tensor); | |||
| if (index != tensors_.size()) { | |||
| output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| last_kernel = kernel; | |||
| } | |||
| } | |||
| } // namespace mindspore::session | |||
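| // Currently only SoftmaxCrossEntropy is recognized as a loss kernel. | |||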
| bool TrainSession::IsLossKernel(kernel::LiteKernel *kernel) { | |||
| return (kernel->Type() == schema::PrimitiveType_SoftmaxCrossEntropy); | |||
| } | |||
| } // namespace lite | |||
| session::TrainSession *session::TrainSession::CreateSession(lite::Context *context) { | |||
| auto session = new lite::TrainSession(); | |||
| auto ret = session->Init(context); | |||
| if (ret != mindspore::lite::RET_OK) { | |||
| MS_LOG(ERROR) << "init session failed"; | |||
| delete session; | |||
| return nullptr; | |||
| } | |||
| return session; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,94 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||
| #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include <unordered_map> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "include/train_session.h" | |||
| #include "include/train_model.h" | |||
| #include "src/lite_session.h" | |||
| /* | |||
| Inheritance Diagram | |||
| +-------------------------------+ | |||
| | session::LiteSession | | |||
| +--------+------------+---------+ | |||
| / \ | |||
| +-----------------+-----+ +-------+------------+ | |||
| | session::TrainSession | | lite::LiteSession | | |||
| +-----------------+-----+ +-------+------------+ | |||
| \ / | |||
| +--------+------------+---------+ | |||
| | lite::TrainSession | | |||
| +-------------------------------+ | |||
| */ | |||
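| /* | |||
|  * Both session::TrainSession and lite::LiteSession derive from | |||
|  * session::LiteSession, so lite::TrainSession sits at the bottom of a | |||
|  * diamond; it inherits its direct bases virtually (see the declaration | |||
|  * below) and forwards the session::LiteSession interface calls to | |||
|  * lite::LiteSession. | |||
|  */ | |||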
| namespace mindspore { | |||
| namespace lite { | |||
| using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>; | |||
| class TrainSession : virtual public session::TrainSession, virtual public lite::LiteSession { | |||
| public: | |||
| TrainSession(); | |||
| ~TrainSession(); | |||
| int RunGraph(const session::KernelCallBack &before = nullptr, | |||
| const session::KernelCallBack &after = nullptr) override; | |||
| int CompileGraph(lite::Model *model) override; | |||
| int CompileTrainGraph(lite::TrainModel *model) override; | |||
| void *ExportToBuf(char *buf, size_t *len) const override; | |||
| void Train() override; | |||
| void Eval() override; | |||
| void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); } | |||
| std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); } | |||
| mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override { | |||
| return lite::LiteSession::GetInputsByTensorName(tensor_name); | |||
| } | |||
| std::vector<tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const override { | |||
| return lite::LiteSession::GetOutputsByNodeName(node_name); | |||
| } | |||
| std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const override { | |||
| return lite::LiteSession::GetOutputs(); | |||
| } | |||
| std::vector<std::string> GetOutputTensorNames() const override { return lite::LiteSession::GetOutputTensorNames(); } | |||
| mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const override { | |||
| return lite::LiteSession::GetOutputByTensorName(tensor_name); | |||
| } | |||
| int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) override { | |||
| return lite::LiteSession::Resize(inputs, dims); | |||
| } | |||
| protected: | |||
| void AllocWorkSpace(); | |||
| virtual std::vector<CreatorOp> ReplaceOps(); | |||
| virtual void RestoreOps(const std::vector<CreatorOp> &restore); | |||
| bool IsLossKernel(kernel::LiteKernel *kernel); | |||
| TrainModel *model_ = nullptr; | |||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_map_; | |||
| std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||
| @@ -112,9 +112,13 @@ TEST_F(TestArithmeticGradFp32, TestAddGradFp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_AddGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->MutableData()); | |||
| @@ -146,9 +150,13 @@ TEST_F(TestArithmeticGradFp32, TestAddGrad2Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_AddGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->MutableData()); | |||
| @@ -182,9 +190,13 @@ TEST_F(TestArithmeticGradFp32, TestAddGrad3Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_AddGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->MutableData()); | |||
| @@ -219,9 +231,13 @@ TEST_F(TestArithmeticGradFp32, TestSubGradFp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_SubGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SubGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->MutableData()); | |||
| @@ -256,9 +272,13 @@ TEST_F(TestArithmeticGradFp32, TestSubGrad2Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_SubGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SubGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->MutableData()); | |||
| @@ -291,9 +311,13 @@ TEST_F(TestArithmeticGradFp32, TestMulGradFp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_MulGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| int loop_count = 1000; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| @@ -336,9 +360,13 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad2Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_MulGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->MutableData()); | |||
| @@ -372,9 +400,13 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad3Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_MulGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->MutableData()); | |||
| @@ -408,9 +440,13 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad4Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_MulGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->MutableData()); | |||
| @@ -444,9 +480,13 @@ TEST_F(TestArithmeticGradFp32, TestDivGradFp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_DivGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->MutableData()); | |||
| @@ -480,9 +520,13 @@ TEST_F(TestArithmeticGradFp32, TestDivGrad2Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[4], all_tensors[3]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_DivGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[0]->MutableData()); | |||
| @@ -517,9 +561,13 @@ TEST_F(TestArithmeticGradFp32, TestDivGrad3Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_DivGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->MutableData()); | |||
| @@ -553,9 +601,13 @@ TEST_F(TestArithmeticGradFp32, Test3DDivGrad2Fp32) { | |||
| std::vector<lite::Tensor *> outputs = {all_tensors[3], all_tensors[4]}; | |||
| auto param = PopulateArithmeticParameter(schema::PrimitiveType_DivGrad, inputs, outputs); | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| float *output_ptr = reinterpret_cast<float *>(outputs[1]->MutableData()); | |||
| @@ -45,10 +45,13 @@ TEST_F(TestBiasGradFp32, BiasGradFp32) { | |||
| dw_tensor.SetData(output_data); | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_BiasGrad}; | |||
| lite::InnerContext ctx; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_BiasGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(bias_param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(bias_param), &ctx, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| @@ -58,19 +58,24 @@ TEST_F(TestBNGradFp32, BNGradFp32) { | |||
| auto var_tensor = CreateInTensor("./test_data/bngrad/save_var_3.bin", {1, 1, 1, channels}); | |||
| // prepare output tensors | |||
| lite::Tensor dx_tensor(TypeId::kNumberTypeFloat32, {batch, height, width, channels}); | |||
| dx_tensor.MallocData(); | |||
| ASSERT_EQ(dx_tensor.MallocData(), 0); | |||
| lite::Tensor dscale_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| dscale_tensor.MallocData(); | |||
| ASSERT_EQ(dscale_tensor.MallocData(), 0); | |||
| lite::Tensor dbias_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| dbias_tensor.MallocData(); | |||
| ASSERT_EQ(dbias_tensor.MallocData(), 0); | |||
| std::vector<lite::Tensor *> inputs = {dy_tensor, x_tensor, scale_tensor, mean_tensor, var_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dx_tensor, &dscale_tensor, &dbias_tensor}; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_BNGrad}; | |||
| lite::InnerContext ctx; | |||
| ctx.device_type_ = lite::DT_CPU; | |||
| ctx.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_BNGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(bn_param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(bn_param), &ctx, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel_obj->GetWorkspaceSize()); | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel_obj->Run(); | |||
| @@ -107,6 +112,7 @@ TEST_F(TestBNGradFp32, BNGradFp32) { | |||
| v->SetData(nullptr); | |||
| delete v; | |||
| } | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete kernel_obj; | |||
| MS_LOG(INFO) << "BNGradFp32 passed"; | |||
| } | |||
| @@ -114,6 +120,7 @@ TEST_F(TestBNGradFp32, BNGradFp32) { | |||
| TEST_F(TestBNGradFp32, BNTrainFp32) { | |||
| auto bn_param = static_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter))); | |||
| bn_param->epsilon_ = 0.00001; | |||
| bn_param->momentum_ = 0.; | |||
| const int batch = 2; | |||
| const int channels = 3; | |||
| const int height = 4; | |||
| @@ -122,22 +129,22 @@ TEST_F(TestBNGradFp32, BNTtrainFp32) { | |||
| auto x_tensor = CreateInTensor("./test_data/bngrad/input_x_2_4_5_3.bin", {batch, height, width, channels}); | |||
| lite::Tensor scale_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| scale_tensor.MallocData(); | |||
| ASSERT_EQ(scale_tensor.MallocData(), 0); | |||
| auto scale = reinterpret_cast<float *>(scale_tensor.MutableData()); | |||
| std::fill(scale, scale + channels, 1.0f); | |||
| lite::Tensor bias_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| bias_tensor.MallocData(); | |||
| ASSERT_EQ(bias_tensor.MallocData(), 0); | |||
| auto bias = reinterpret_cast<float *>(bias_tensor.MutableData()); | |||
| std::fill(bias, bias + channels, 1.0f); | |||
| lite::Tensor mean_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| mean_tensor.MallocData(); | |||
| ASSERT_EQ(mean_tensor.MallocData(), 0); | |||
| auto mean = reinterpret_cast<float *>(mean_tensor.MutableData()); | |||
| std::fill(mean, mean + channels, 0.0f); | |||
| lite::Tensor var_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| var_tensor.MallocData(); | |||
| ASSERT_EQ(var_tensor.MallocData(), 0); | |||
| auto var = reinterpret_cast<float *>(var_tensor.MutableData()); | |||
| std::fill(var, var + channels, 1.0f); | |||
| @@ -146,11 +153,11 @@ TEST_F(TestBNGradFp32, BNTtrainFp32) { | |||
| lite::Tensor out_tensor(TypeId::kNumberTypeFloat32, {batch, height, width, channels}); | |||
| ASSERT_EQ(out_tensor.MallocData(), 0); | |||
| lite::Tensor run_mean_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(run_mean_tensor.MallocData(), 0); | |||
| lite::Tensor save_scale_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(save_scale_tensor.MallocData(), 0); | |||
| lite::Tensor run_var_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(run_var_tensor.MallocData(), 0); | |||
| lite::Tensor save_bias_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(save_bias_tensor.MallocData(), 0); | |||
| lite::Tensor save_mean_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(save_mean_tensor.MallocData(), 0); | |||
| @@ -158,7 +165,7 @@ TEST_F(TestBNGradFp32, BNTtrainFp32) { | |||
| lite::Tensor save_var_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); | |||
| ASSERT_EQ(save_var_tensor.MallocData(), 0); | |||
| std::vector<lite::Tensor *> outputs = {&out_tensor, &run_mean_tensor, &run_var_tensor, &save_mean_tensor, | |||
| std::vector<lite::Tensor *> outputs = {&out_tensor, &save_scale_tensor, &save_bias_tensor, &save_mean_tensor, | |||
| &save_var_tensor}; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_FusedBatchNorm}; | |||
| @@ -170,26 +177,31 @@ TEST_F(TestBNGradFp32, BNTtrainFp32) { | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(bn_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel_obj->GetWorkspaceSize()); | |||
| float *save_mean = reinterpret_cast<float *>(save_mean_tensor.MutableData()); | |||
| float *save_var = reinterpret_cast<float *>(save_var_tensor.MutableData()); | |||
| std::fill(save_mean, save_mean + channels, 0.f); | |||
| std::fill(save_var, save_var + channels, 0.f); | |||
| kernel_obj->train(); | |||
| kernel_obj->Run(); | |||
| float *run_mean = reinterpret_cast<float *>(run_mean_tensor.MutableData()); | |||
| float *run_var = reinterpret_cast<float *>(run_var_tensor.MutableData()); | |||
| std::cout << "================run_mean==============================\n"; | |||
| for (int i = 0; i < channels; i++) std::cout << run_mean[i] << " "; | |||
| std::cout << "================save_mean==============================\n"; | |||
| for (int i = 0; i < channels; i++) std::cout << save_mean[i] << " "; | |||
| std::cout << "\n"; | |||
| std::cout << "================run_var==============================\n"; | |||
| for (int i = 0; i < channels; i++) std::cout << run_var[i] << " "; | |||
| std::cout << "================save_var==============================\n"; | |||
| for (int i = 0; i < channels; i++) std::cout << save_var[i] << " "; | |||
| std::cout << "\n"; | |||
| delete[] reinterpret_cast<float *>(x_tensor->MutableData()); | |||
| auto res = mindspore::lite::CompareRelativeOutput(run_mean, "./test_data/bngrad/running_mean_3.bin"); | |||
| auto res = mindspore::lite::CompareRelativeOutput(save_mean, "./test_data/bngrad/running_mean_3.bin"); | |||
| EXPECT_EQ(res, 0); | |||
| res = mindspore::lite::CompareRelativeOutput(run_var, "./test_data/bngrad/running_var_3.bin"); | |||
| res = mindspore::lite::CompareRelativeOutput(save_var, "./test_data/bngrad/running_var_3.bin"); | |||
| EXPECT_EQ(res, 0); | |||
| x_tensor->SetData(nullptr); | |||
| delete x_tensor; | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete kernel_obj; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -107,10 +107,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) { | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc, nullptr); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| @@ -134,6 +139,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) { | |||
| delete[] input_data; | |||
| delete[] dy_data; | |||
| delete[] dw_data; | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete kernel; | |||
| // delete conv_param; | |||
| dw_tensor.SetData(nullptr); | |||
| @@ -175,9 +181,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) { | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc, nullptr); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| @@ -203,6 +215,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) { | |||
| w_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| dx_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete kernel; | |||
| // delete conv_param; | |||
| @@ -241,10 +254,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) { | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc, nullptr); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| @@ -270,6 +288,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) { | |||
| dw_tensor.SetData(nullptr); | |||
| x_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete kernel; | |||
| // delete conv_param; | |||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; | |||
| @@ -308,10 +327,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) { | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc, nullptr); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| @@ -338,6 +362,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) { | |||
| dy_tensor.SetData(nullptr); | |||
| delete kernel; | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| // delete conv_param; | |||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Input Grad passed"; | |||
| } | |||
| @@ -375,9 +400,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) { | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc, nullptr); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| @@ -403,6 +434,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) { | |||
| dw_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| x_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete kernel; | |||
| // delete conv_param; | |||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; | |||
| @@ -441,14 +473,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationInputGrad) { | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), NULL, desc, nullptr); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| @@ -469,6 +502,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationInputGrad) { | |||
| dx_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| w_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete kernel; | |||
| // delete conv_param; | |||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Input Grad passed"; | |||
| @@ -515,6 +549,8 @@ TEST_F(TestConvolutionGradFp32, ConvGroupDilation) { | |||
| auto *kernel = new mindspore::kernel::ConvolutionTrainCPUKernel(reinterpret_cast<OpParameter *>(conv_param), inputs, | |||
| outputs, &context, 0); | |||
| kernel->Init(); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| kernel->train(); | |||
| EXPECT_EQ(kernel->is_train(), 1); | |||
| @@ -543,9 +579,208 @@ TEST_F(TestConvolutionGradFp32, ConvGroupDilation) { | |||
| x_tensor.SetData(nullptr); | |||
| y_tensor.SetData(nullptr); | |||
| w_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete kernel; | |||
| MS_LOG(INFO) << "TestConvolutionGradFp32 ConvGroupDilation passed"; | |||
| } | |||
| TEST_F(TestConvolutionGradFp32, ConvFp32Dilation2Group2Stride2FilterGrad) { | |||
| // prepare stage | |||
| auto conv_param = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||
| conv_param->input_batch_ = 2; | |||
| conv_param->input_h_ = 32; | |||
| conv_param->input_w_ = 32; | |||
| conv_param->input_channel_ = 4; | |||
| conv_param->output_batch_ = 2; | |||
| conv_param->output_h_ = 15; | |||
| conv_param->output_w_ = 15; | |||
| conv_param->output_channel_ = 12; | |||
| conv_param->kernel_h_ = 3; | |||
| conv_param->kernel_w_ = 3; | |||
| conv_param->stride_h_ = 2; | |||
| conv_param->stride_w_ = 2; | |||
| conv_param->dilation_h_ = 2; | |||
| conv_param->dilation_w_ = 2; | |||
| conv_param->pad_u_ = 1; | |||
| conv_param->pad_l_ = 1; | |||
| conv_param->pad_r_ = 1; | |||
| conv_param->pad_d_ = 1; | |||
| conv_param->group_ = 2; | |||
| conv_param->act_type_ = ActType_No; | |||
| conv_param->thread_num_ = 1; | |||
| size_t dy_size; | |||
| std::string dy_path = "./test_data/conv/convfp32_dy_d2_g2_s2_2_12_15_15.bin"; | |||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||
| std::vector<int> dim_dy({2, 15, 15, 12}); | |||
| lite::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||
| dy_tensor.SetData(dy_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| size_t output_data_size = | |||
| conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/conv/convfp32_input0_d2_g2_s2_2_4_32_32.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| std::vector<int> dim_x({2, 32, 32, 4}); | |||
| lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| x_tensor.SetData(input_data); | |||
| auto dw_data = new float[output_data_size]; | |||
| std::vector<int> dim_dw({12, 3, 3, 2}); | |||
| lite::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||
| dw_tensor.SetData(dw_data); | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/conv/convfp32_dw_d2_g2_s2_12_2_3_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] dy_data; | |||
| delete[] dw_data; | |||
| delete kernel; | |||
| // delete conv_param; | |||
| dw_tensor.SetData(nullptr); | |||
| x_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; | |||
| } | |||
| TEST_F(TestConvolutionGradFp32, ConvGroup2Dilation2Stride2) { | |||
| // prepare stage | |||
| auto conv_param = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||
| conv_param->input_batch_ = 2; | |||
| conv_param->input_h_ = 32; | |||
| conv_param->input_w_ = 32; | |||
| conv_param->input_channel_ = 4; | |||
| conv_param->output_batch_ = 2; | |||
| conv_param->output_h_ = 15; | |||
| conv_param->output_w_ = 15; | |||
| conv_param->output_channel_ = 12; | |||
| conv_param->kernel_h_ = 3; | |||
| conv_param->kernel_w_ = 3; | |||
| conv_param->stride_h_ = 2; | |||
| conv_param->stride_w_ = 2; | |||
| conv_param->dilation_h_ = 2; | |||
| conv_param->dilation_w_ = 2; | |||
| conv_param->pad_u_ = 1; | |||
| conv_param->pad_l_ = 1; | |||
| conv_param->pad_r_ = 1; | |||
| conv_param->pad_d_ = 1; | |||
| conv_param->group_ = 2; | |||
| conv_param->act_type_ = ActType_No; | |||
| conv_param->thread_num_ = 1; | |||
| size_t dy_size; | |||
| std::string dy_path = "./test_data/conv/convfp32_dy_d2_g2_s2_2_12_15_15.bin"; | |||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||
| std::vector<int> dim_dy({2, 15, 15, 12}); | |||
| lite::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||
| dy_tensor.SetData(dy_data); | |||
| size_t w_size; | |||
| std::string w_path = "./test_data/conv/convfp32_w_d2_g2_s2_12_2_3_3.bin"; | |||
| auto w_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); | |||
| std::vector<int> dim_w({12, 3, 3, 2}); | |||
| lite::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_w); | |||
| w_tensor.SetData(w_data); | |||
| size_t output_data_size = | |||
| conv_param->input_batch_ * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; | |||
| auto dx_data = new float[output_data_size]; | |||
| std::vector<int> dim_dx({2, 32, 32, 4}); | |||
| lite::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); | |||
| dx_tensor.SetData(dx_data); | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &w_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dx_tensor}; | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/conv/convfp32_inputdx_d2_g2_s2_2_4_32_32.bin"; | |||
| auto res = lite::CompareRelativeOutput(dx_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] dx_data; | |||
| delete[] w_data; | |||
| delete[] dy_data; | |||
| dx_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| w_tensor.SetData(nullptr); | |||
| delete kernel; | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| MS_LOG(INFO) << "TestConvolutionGradFp32 Input Grad passed"; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,634 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <vector> | |||
| // #include "utils/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "src/common/file_utils.h" | |||
| #include "src/common/file_utils_ext.h" | |||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_filter.h" | |||
| #include "mindspore/lite/nnacl/conv_parameter.h" | |||
| #include "mindspore/lite/src/kernel_registry.h" | |||
| namespace mindspore { | |||
| class TestDeConvolutionGradFp32 : public mindspore::CommonTest { | |||
| public: | |||
| TestDeConvolutionGradFp32() {} | |||
| }; | |||
| TEST_F(TestDeConvolutionGradFp32, DeConvFp32FilterGrad) { | |||
| // prepare stage | |||
| auto conv_param = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||
| conv_param->input_batch_ = 2; | |||
| conv_param->input_h_ = 32; | |||
| conv_param->input_w_ = 32; | |||
| conv_param->input_channel_ = 3; | |||
| conv_param->output_batch_ = 2; | |||
| conv_param->output_h_ = 63; | |||
| conv_param->output_w_ = 63; | |||
| conv_param->output_channel_ = 9; | |||
| conv_param->kernel_h_ = 3; | |||
| conv_param->kernel_w_ = 3; | |||
| conv_param->stride_h_ = 2; | |||
| conv_param->stride_w_ = 2; | |||
| conv_param->dilation_h_ = 1; | |||
| conv_param->dilation_w_ = 1; | |||
| conv_param->pad_u_ = 1; | |||
| conv_param->pad_l_ = 1; | |||
| conv_param->pad_r_ = 1; | |||
| conv_param->pad_d_ = 1; | |||
| conv_param->group_ = 1; | |||
| conv_param->act_type_ = ActType_No; | |||
| conv_param->thread_num_ = 1; | |||
| size_t dy_size; | |||
| std::string dy_path = "./test_data/deconv/deconvfp32_dy_2_9_63_63.bin"; | |||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||
| std::vector<int> dim_dy({2, 63, 63, 9}); | |||
| lite::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||
| dy_tensor.SetData(dy_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| size_t output_data_size = | |||
| conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/deconv/deconvfp32_input0_2_3_32_32.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| std::vector<int> dim_x({2, 32, 32, 3}); | |||
| lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| x_tensor.SetData(input_data); | |||
| auto dw_data = new float[output_data_size]; | |||
| std::vector<int> dim_dw({3, 3, 3, 9}); | |||
| lite::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||
| dw_tensor.SetData(dw_data); | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DeConv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/deconv/deconvfp32_dw_9_3_3_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] dy_data; | |||
| delete[] dw_data; | |||
| delete kernel; | |||
| // delete conv_param; | |||
| dw_tensor.SetData(nullptr); | |||
| x_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| MS_LOG(INFO) << "TestDeConvolutionGradFp32 Filter Grad passed"; | |||
| } | |||
| TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2FilterGrad) { | |||
| // prepare stage | |||
| auto conv_param = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||
| conv_param->input_batch_ = 2; | |||
| conv_param->input_h_ = 32; | |||
| conv_param->input_w_ = 32; | |||
| conv_param->input_channel_ = 3; | |||
| conv_param->output_batch_ = 2; | |||
| conv_param->output_h_ = 65; | |||
| conv_param->output_w_ = 65; | |||
| conv_param->output_channel_ = 9; | |||
| conv_param->kernel_h_ = 3; | |||
| conv_param->kernel_w_ = 3; | |||
| conv_param->stride_h_ = 2; | |||
| conv_param->stride_w_ = 2; | |||
| conv_param->dilation_h_ = 2; | |||
| conv_param->dilation_w_ = 2; | |||
| conv_param->pad_u_ = 1; | |||
| conv_param->pad_l_ = 1; | |||
| conv_param->pad_r_ = 1; | |||
| conv_param->pad_d_ = 1; | |||
| conv_param->group_ = 1; | |||
| conv_param->act_type_ = ActType_No; | |||
| conv_param->thread_num_ = 1; | |||
| size_t dy_size; | |||
| std::string dy_path = "./test_data/deconv/deconvfp32_dy_d2_2_9_65_65.bin"; | |||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||
| std::vector<int> dim_dy({2, 65, 65, 9}); | |||
| lite::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||
| dy_tensor.SetData(dy_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| size_t output_data_size = | |||
| conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/deconv/deconvfp32_input0_d2_2_3_32_32.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| std::vector<int> dim_x({2, 32, 32, 3}); | |||
| lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| x_tensor.SetData(input_data); | |||
| auto dw_data = new float[output_data_size]; | |||
| std::vector<int> dim_dw({9, 3, 3, 3}); | |||
| lite::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||
| dw_tensor.SetData(dw_data); | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DeConv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/deconv/deconvfp32_dw_d2_9_3_3_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] dy_data; | |||
| delete[] dw_data; | |||
| delete kernel; | |||
| // delete conv_param; | |||
| dw_tensor.SetData(nullptr); | |||
| x_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| MS_LOG(INFO) << "TestDeConvolutionGradFp32 Filter Grad passed"; | |||
| } | |||
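| // The spatial sizes in these tests follow the standard transposed-convolution | |||
| // relation; DeconvOutputSize below is a hypothetical reference helper, not kernel API: | |||
| //   output = (input - 1) * stride - (pad_u + pad_d) + dilation * (kernel - 1) + 1 | |||
| // e.g. (32 - 1) * 2 - (1 + 1) + 2 * (3 - 1) + 1 = 65, matching output_h_ above, | |||
| // and with stride 1 it gives 34, matching the stride-1 test below. | |||
| static int DeconvOutputSize(int input, int stride, int kernel, int dilation, int pad_begin, int pad_end) { | |||
|   return (input - 1) * stride - (pad_begin + pad_end) + dilation * (kernel - 1) + 1; | |||
| } | |||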
| TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group3FilterGrad) { | |||
| // prepare stage | |||
| auto conv_param = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||
| conv_param->input_batch_ = 2; | |||
| conv_param->input_h_ = 32; | |||
| conv_param->input_w_ = 32; | |||
| conv_param->input_channel_ = 3; | |||
| conv_param->output_batch_ = 2; | |||
| conv_param->output_h_ = 65; | |||
| conv_param->output_w_ = 65; | |||
| conv_param->output_channel_ = 9; | |||
| conv_param->kernel_h_ = 3; | |||
| conv_param->kernel_w_ = 3; | |||
| conv_param->stride_h_ = 2; | |||
| conv_param->stride_w_ = 2; | |||
| conv_param->dilation_h_ = 2; | |||
| conv_param->dilation_w_ = 2; | |||
| conv_param->pad_u_ = 1; | |||
| conv_param->pad_l_ = 1; | |||
| conv_param->pad_r_ = 1; | |||
| conv_param->pad_d_ = 1; | |||
| conv_param->group_ = 3; | |||
| conv_param->act_type_ = ActType_No; | |||
| conv_param->thread_num_ = 1; | |||
| size_t dy_size; | |||
| std::string dy_path = "./test_data/deconv/deconvfp32_dy_d2_g3_2_9_65_65.bin"; | |||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||
| std::vector<int> dim_dy({2, 65, 65, 9}); | |||
| lite::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||
| dy_tensor.SetData(dy_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| size_t output_data_size = | |||
| conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/deconv/deconvfp32_input0_d2_g3_2_3_32_32.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| std::vector<int> dim_x({2, 32, 32, 3}); | |||
| lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| x_tensor.SetData(input_data); | |||
| auto dw_data = new float[output_data_size]; | |||
| std::vector<int> dim_dw({3, 3, 3, 3}); | |||
| lite::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||
| dw_tensor.SetData(dw_data); | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DeConv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/deconv/deconvfp32_dw_d2_g3_3_3_3_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] dy_data; | |||
| delete[] dw_data; | |||
| delete kernel; | |||
| // delete conv_param; | |||
| dw_tensor.SetData(nullptr); | |||
| x_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| MS_LOG(INFO) << "TestDeConvolutionGradFp32 Filter Grad passed"; | |||
| } | |||
| TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group3Stride1FilterGrad) { | |||
| // prepare stage | |||
| auto conv_param = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||
| conv_param->input_batch_ = 2; | |||
| conv_param->input_h_ = 32; | |||
| conv_param->input_w_ = 32; | |||
| conv_param->input_channel_ = 3; | |||
| conv_param->output_batch_ = 2; | |||
| conv_param->output_h_ = 34; | |||
| conv_param->output_w_ = 34; | |||
| conv_param->output_channel_ = 9; | |||
| conv_param->kernel_h_ = 3; | |||
| conv_param->kernel_w_ = 3; | |||
| conv_param->stride_h_ = 1; | |||
| conv_param->stride_w_ = 1; | |||
| conv_param->dilation_h_ = 2; | |||
| conv_param->dilation_w_ = 2; | |||
| conv_param->pad_u_ = 1; | |||
| conv_param->pad_l_ = 1; | |||
| conv_param->pad_r_ = 1; | |||
| conv_param->pad_d_ = 1; | |||
| conv_param->group_ = 3; | |||
| conv_param->act_type_ = ActType_No; | |||
| conv_param->thread_num_ = 1; | |||
| size_t dy_size; | |||
| std::string dy_path = "./test_data/deconv/deconvfp32_dy_d2_g3_s1_2_9_34_34.bin"; | |||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||
| std::vector<int> dim_dy({2, 34, 34, 9}); | |||
| lite::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||
| dy_tensor.SetData(dy_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| size_t output_data_size = | |||
| conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/deconv/deconvfp32_input0_d2_g3_s1_2_3_32_32.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| std::vector<int> dim_x({2, 32, 32, 3}); | |||
| lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| x_tensor.SetData(input_data); | |||
| auto dw_data = new float[output_data_size]; | |||
| std::vector<int> dim_dw({3, 3, 3, 3}); | |||
| lite::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||
| dw_tensor.SetData(dw_data); | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DeConv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/deconv/deconvfp32_dw_d2_g3_s1_3_3_3_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] dy_data; | |||
| delete[] dw_data; | |||
| delete kernel; | |||
| // delete conv_param; | |||
| dw_tensor.SetData(nullptr); | |||
| x_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| MS_LOG(INFO) << "TestDeConvolutionGradFp32 Filter Grad passed"; | |||
| } | |||
| TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group2Stride2FilterGrad) { | |||
| // prepare stage | |||
| auto conv_param = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||
| conv_param->input_batch_ = 2; | |||
| conv_param->input_h_ = 32; | |||
| conv_param->input_w_ = 32; | |||
| conv_param->input_channel_ = 4; | |||
| conv_param->output_batch_ = 2; | |||
| conv_param->output_h_ = 65; | |||
| conv_param->output_w_ = 65; | |||
| conv_param->output_channel_ = 12; | |||
| conv_param->kernel_h_ = 3; | |||
| conv_param->kernel_w_ = 3; | |||
| conv_param->stride_h_ = 2; | |||
| conv_param->stride_w_ = 2; | |||
| conv_param->dilation_h_ = 2; | |||
| conv_param->dilation_w_ = 2; | |||
| conv_param->pad_u_ = 1; | |||
| conv_param->pad_l_ = 1; | |||
| conv_param->pad_r_ = 1; | |||
| conv_param->pad_d_ = 1; | |||
| conv_param->group_ = 2; | |||
| conv_param->act_type_ = ActType_No; | |||
| conv_param->thread_num_ = 1; | |||
| size_t dy_size; | |||
| std::string dy_path = "./test_data/deconv/deconvfp32_dy_d2_g2_s2_2_12_65_65.bin"; | |||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||
| std::vector<int> dim_dy({2, 65, 65, 12}); | |||
| lite::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||
| dy_tensor.SetData(dy_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| size_t output_data_size = | |||
| conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/deconv/deconvfp32_input0_d2_g2_s2_2_4_32_32.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| std::vector<int> dim_x({2, 32, 32, 4}); | |||
| lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| x_tensor.SetData(input_data); | |||
| auto dw_data = new float[output_data_size]; | |||
| std::vector<int> dim_dw({6, 3, 3, 4}); | |||
| lite::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||
| dw_tensor.SetData(dw_data); | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DeConv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/deconv/deconvfp32_dw_d2_g2_s2_6_4_3_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] dy_data; | |||
| delete[] dw_data; | |||
| delete kernel; | |||
| // delete conv_param; | |||
| dw_tensor.SetData(nullptr); | |||
| x_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| MS_LOG(INFO) << "TestDeConvolutionGradFp32 Filter Grad passed"; | |||
| } | |||
| TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group12Stride2FilterGrad) { | |||
| // prepare stage | |||
| auto conv_param = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||
| conv_param->input_batch_ = 2; | |||
| conv_param->input_h_ = 32; | |||
| conv_param->input_w_ = 32; | |||
| conv_param->input_channel_ = 12; | |||
| conv_param->output_batch_ = 2; | |||
| conv_param->output_h_ = 65; | |||
| conv_param->output_w_ = 65; | |||
| conv_param->output_channel_ = 12; | |||
| conv_param->kernel_h_ = 3; | |||
| conv_param->kernel_w_ = 3; | |||
| conv_param->stride_h_ = 2; | |||
| conv_param->stride_w_ = 2; | |||
| conv_param->dilation_h_ = 2; | |||
| conv_param->dilation_w_ = 2; | |||
| conv_param->pad_u_ = 1; | |||
| conv_param->pad_l_ = 1; | |||
| conv_param->pad_r_ = 1; | |||
| conv_param->pad_d_ = 1; | |||
| conv_param->group_ = 12; | |||
| conv_param->act_type_ = ActType_No; | |||
| conv_param->thread_num_ = 1; | |||
| size_t dy_size; | |||
| std::string dy_path = "./test_data/deconv/deconvfp32_dy_d2_g12_s2_2_12_65_65.bin"; | |||
| auto dy_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dy_path.c_str(), &dy_size)); | |||
| std::vector<int> dim_dy({2, 65, 65, 12}); | |||
| lite::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); | |||
| dy_tensor.SetData(dy_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| size_t output_data_size = | |||
| conv_param->output_channel_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/deconv/deconvfp32_input0_d2_g12_s2_2_12_32_32.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| std::vector<int> dim_x({2, 32, 32, 12}); | |||
| lite::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); | |||
| x_tensor.SetData(input_data); | |||
| auto dw_data = new float[output_data_size]; | |||
| std::vector<int> dim_dw({1, 3, 3, 12}); | |||
| lite::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); | |||
| dw_tensor.SetData(dw_data); | |||
| std::vector<lite::Tensor *> inputs = {&dy_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&dw_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DeConv2DGradFilter}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(conv_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel->GetWorkspaceSize()); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/deconv/deconvfp32_dw_d2_g12_s2_12_1_3_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(dw_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] dy_data; | |||
| delete[] dw_data; | |||
| delete kernel; | |||
| // delete conv_param; | |||
| dw_tensor.SetData(nullptr); | |||
| x_tensor.SetData(nullptr); | |||
| dy_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| MS_LOG(INFO) << "TestDeConvolutionGradFp32 Filter Grad passed"; | |||
| } | |||
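| // Inferred from the dw shapes in the grouped tests above (an assumption, not a | |||
| // documented layout): dim_dw is {out_ch / group, kh, kw, in_ch}, so the filter | |||
| // gradient holds output_data_size / group_ elements. For group_ = 12 that is | |||
| // 1 * 3 * 3 * 12 = 108 elements, while dw_data is allocated for 12 * 3 * 3 * 12. | |||
| // A hypothetical helper capturing that inference (a sketch, not test code): | |||
| static size_t GroupedFilterGradSize(const ConvParameter *p) { | |||
|   return static_cast<size_t>(p->output_channel_ / p->group_) * p->kernel_h_ * p->kernel_w_ * p->input_channel_; | |||
| } | |||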
| } // namespace mindspore | |||
| @@ -26,12 +26,13 @@ | |||
| #include "mindspore/lite/include/train_model.h" | |||
| #include "common/common_test.h" | |||
| #include "include/train_session.h" | |||
| // #include "include/lite_session.h" | |||
| #include "include/context.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/file_utils.h" | |||
| #include "src/common/file_utils_ext.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/fp32_grad/convolution.h" | |||
| namespace mindspore { | |||
| class NetworkTest : public mindspore::CommonTest { | |||
| @@ -39,6 +40,9 @@ class NetworkTest : public mindspore::CommonTest { | |||
| NetworkTest() {} | |||
| }; | |||
| int32_t runNet(mindspore::session::LiteSession *session, const std::string &in, const std::string &out, | |||
| const char *tensor_name, bool debug = false); | |||
| // INPUT(0) | |||
| // V | |||
| // +-------------+ | |||
| @@ -352,15 +356,13 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| ASSERT_NE(nullptr, model); | |||
| meta_graph.reset(); | |||
| content = nullptr; | |||
| lite::InnerContext context; | |||
| lite::Context context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.cpu_bind_mode_ = lite::NO_BIND; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| auto session = new session::TrainSession(); | |||
| auto session = session::TrainSession::CreateSession(&context); | |||
| ASSERT_NE(nullptr, session); | |||
| session->Init(&context); | |||
| auto ret = session->CompileGraph(model); | |||
| auto ret = session->CompileTrainGraph(model); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| session->Train(); | |||
| session->Train(); // Just double-check that calling Train() twice does not cause a problem | |||
| @@ -469,59 +471,67 @@ int32_t fileIterator(mindspore::session::TrainSession *session, const std::strin | |||
| } | |||
| void replaceExt(const std::string &src, std::string *dst) { *dst = src.substr(0, src.find_last_of('.')) + ".emb"; } | |||
| int32_t runNet(mindspore::lite::LiteSession *session, const std::string &in, const std::string &out, | |||
| const char *tensor_name) { | |||
| int32_t runNet(mindspore::session::LiteSession *session, const std::string &in, const std::string &out, | |||
| const char *tensor_name, bool debug) { | |||
| // setup input | |||
| auto inputs = session->GetInputs(); | |||
| auto inTensor = inputs.at(0); | |||
| float *data = reinterpret_cast<float *>(inTensor->MutableData()); | |||
| size_t input_size; | |||
| float *in_buf = reinterpret_cast<float *>(lite::ReadFile(in.c_str(), &input_size)); | |||
| auto input_data = reinterpret_cast<float *>(in_buf); | |||
| std::copy(input_data, input_data + inTensor->ElementsNum(), data); | |||
| std::cout << "==============Input===========================" << std::endl; | |||
| for (int i = 0; i < 10; i++) { | |||
| std::cout << data[i] << ", "; | |||
| } | |||
| std::cout << std::endl; | |||
| delete[] in_buf; | |||
| // execute network | |||
| session->RunGraph(); | |||
| // compare outputs | |||
| auto output = session->GetOutputByTensorName(tensor_name); | |||
| float *output_data = reinterpret_cast<float *>(output->MutableData()); | |||
| if (output != nullptr) { | |||
| float *output_data = reinterpret_cast<float *>(output->MutableData()); | |||
| // compare outputs | |||
| if (debug) { | |||
| std::cout << "==============Output===========================" << std::endl; | |||
| for (int i = 0; i < 10; i++) { | |||
| std::cout << output_data[i] << ", "; | |||
| } | |||
| std::cout << std::endl; | |||
| } | |||
| return mindspore::lite::CompareRelativeOutput(output_data, out); | |||
| } | |||
| return mindspore::lite::CompareRelativeOutput(output_data, out); | |||
| return lite::RET_ERROR; | |||
| } | |||
| TEST_F(NetworkTest, efficient_net) { | |||
| char *buf = nullptr; | |||
| size_t net_size = 0; | |||
| // std::string net = "./test_data/nets/efficientnet_b0_f.ms"; | |||
| std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms"; | |||
| ReadFile(net.c_str(), &net_size, &buf); | |||
| auto model = lite::TrainModel::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::InnerContext; | |||
| auto context = new lite::Context; | |||
| context->device_type_ = lite::DT_CPU; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context->Init()); | |||
| auto session = new mindspore::session::TrainSession(); | |||
| auto session = session::TrainSession::CreateSession(context); | |||
| ASSERT_NE(session, nullptr); | |||
| auto ret = session->Init(context); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| ret = session->CompileGraph(model); | |||
| auto ret = session->CompileTrainGraph(model); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| session->Eval(); | |||
| std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin"; | |||
| std::string out = "./test_data/nets/effNet_output_y_1_1000.bin"; | |||
| auto res = runNet(session, in, out, "631"); | |||
| ASSERT_EQ(res, 0); | |||
| auto res = runNet(session, in, out, "650"); | |||
| delete session; | |||
| delete context; | |||
| ASSERT_EQ(res, 0); | |||
| } | |||
| TEST_F(NetworkTest, lenetnet) { | |||
| @@ -536,19 +546,105 @@ TEST_F(NetworkTest, lenetnet) { | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| auto session = new mindspore::session::TrainSession(); | |||
| // check registration | |||
| mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance(); | |||
| mindspore::kernel::KernelKey desc1 = {mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, | |||
| mindspore::schema::PrimitiveType_Conv2D}; | |||
| mindspore::kernel::KernelKey desc2 = {mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, | |||
| mindspore::schema::PrimitiveType_DepthwiseConv2D}; | |||
| auto regb1 = reg->GetCreator(desc1); | |||
| auto regb2 = reg->GetCreator(desc2); | |||
| ASSERT_NE(regb1, mindspore::kernel::CpuConvTrainFp32KernelCreator); | |||
| auto session = session::TrainSession::CreateSession(context); | |||
| ASSERT_NE(session, nullptr); | |||
| auto ret = session->Init(context); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| ret = session->CompileGraph(model); | |||
| auto ret = session->CompileTrainGraph(model); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| session->Eval(); | |||
| auto rega1 = reg->GetCreator(desc1); | |||
| auto rega2 = reg->GetCreator(desc2); | |||
| ASSERT_EQ(regb1, rega1); | |||
| ASSERT_EQ(regb2, rega2); | |||
| ASSERT_NE(rega1, mindspore::kernel::CpuConvTrainFp32KernelCreator); | |||
| // end of check registration | |||
| session->Eval(); | |||
| std::string in = "./test_data/nets/x_lenet.bin"; | |||
| std::string out = "./test_data/nets/y_lenet.bin"; | |||
| auto res = runNet(session, in, out, "24"); | |||
| delete session; | |||
| delete context; | |||
| ASSERT_EQ(res, 0); | |||
| } | |||
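| // Minimal sketch of the session lifecycle these tests now exercise (paths, | |||
| // buffers, and error handling elided; every call appears in the tests above): | |||
| //   auto model = lite::TrainModel::Import(buf, net_size); | |||
| //   auto session = session::TrainSession::CreateSession(context); | |||
| //   auto ret = session->CompileTrainGraph(model); | |||
| //   session->Eval();            // or session->Train() for training mode | |||
| //   session->RunGraph();        // invoked via runNet() in these tests | |||
| //   delete session; | |||
| //   delete context; | |||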
| #if 0 | |||
| TEST_F(NetworkTest, retina_net) { | |||
| char *buf = nullptr; | |||
| size_t net_size = 0; | |||
| std::string net = "./test_data/nets/retinaface1009.ms"; | |||
| ReadFile(net.c_str(), &net_size, &buf); | |||
| // auto model = lite::TrainModel::Import(buf, net_size); | |||
| auto model = lite::Model::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::Context; | |||
| context->device_type_ = lite::DT_CPU; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| // auto session = session::TrainSession::CreateSession(context); | |||
| auto session = session::LiteSession::CreateSession(context); | |||
| ASSERT_NE(session, nullptr); | |||
| auto ret = session->CompileGraph(model); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| // session->Eval(); | |||
| std::string in = "./test_data/nets/retinaface_input.f32"; | |||
| std::cout << "----- Output 0 -----" << std::endl; | |||
| std::string out = "./test_data/nets/retinaface_out_0.f32"; | |||
| auto res = runNet(session, in, out, "448", true); | |||
| ASSERT_EQ(res, 0); | |||
| std::cout << "----- Output 1 -----" << std::endl; | |||
| out = "./test_data/nets/retinaface_out_1.f32"; | |||
| res = runNet(session, in, out, "435", true); | |||
| ASSERT_EQ(res, 0); | |||
| std::cout << "----- Output 2 -----" << std::endl; | |||
| out = "./test_data/nets/retinaface_out_2.f32"; | |||
| res = runNet(session, in, out, "421", true); | |||
| ASSERT_EQ(res, 0); | |||
| delete session; | |||
| delete context; | |||
| } | |||
| #endif | |||
| TEST_F(NetworkTest, mobileface_net) { | |||
| char *buf = nullptr; | |||
| size_t net_size = 0; | |||
| std::string net = "./test_data/nets/mobilefacenet0924.ms"; | |||
| ReadFile(net.c_str(), &net_size, &buf); | |||
| // auto model = lite::TrainModel::Import(buf, net_size); | |||
| auto model = lite::Model::Import(buf, net_size); | |||
| delete[] buf; | |||
| auto context = new lite::Context; | |||
| context->device_type_ = lite::DT_CPU; | |||
| context->cpu_bind_mode_ = lite::NO_BIND; | |||
| context->thread_num_ = 1; | |||
| // auto session = session::TrainSession::CreateSession(context); | |||
| auto session = session::LiteSession::CreateSession(context); | |||
| ASSERT_NE(session, nullptr); | |||
| auto ret = session->CompileGraph(model); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| // session->Eval(); | |||
| std::string in = "./test_data/nets/facenet_input.f32"; | |||
| std::string out = "./test_data/nets/facenet_output.f32"; | |||
| auto res = runNet(session, in, out, "354", true); | |||
| ASSERT_EQ(res, 0); | |||
| delete model; | |||
| delete session; | |||
| delete context; | |||
| } | |||
| @@ -20,12 +20,12 @@ | |||
| #include "mindspore/lite/include/context.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/kernel_registry.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/common/file_utils.h" | |||
| #include "src/common/file_utils_ext.h" | |||
| #include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h" | |||
| #include "nnacl/fp32_grad/pooling_grad.h" | |||
| #include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h" | |||
| #include "mindspore/lite/src/kernel_registry.h" | |||
| namespace mindspore { | |||
| class TestPoolingGradFp32 : public mindspore::CommonTest { | |||
| @@ -78,13 +78,13 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) { | |||
| auto output_data = new float[output_data_size]; | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| AvgPoolingGrad(input_data, output_data, pooling_param); | |||
| AvgPoolingGrad(input_data, output_data, pooling_param, 1); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| AvgPoolingGrad(input_data, output_data, pooling_param); | |||
| AvgPoolingGrad(input_data, output_data, pooling_param, 1); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| @@ -140,10 +140,14 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { | |||
| dx_tensor.SetData(output_data); | |||
| std::vector<lite::Tensor *> outputs = {&dx_tensor}; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(pooling_param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(pooling_param), &context, desc, nullptr); | |||
| kernel_obj->Run(); | |||
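| // The recurring change in these hunks: grad kernels are no longer created with | |||
| // a NULL context but with an initialized lite::InnerContext. The pattern, as a sketch: | |||
| //   lite::InnerContext context; | |||
| //   context.device_type_ = lite::DT_CPU; | |||
| //   context.thread_num_ = 1; | |||
| //   ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| //   auto kernel_obj = creator(inputs, outputs, param, &context, desc, nullptr); | |||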
| @@ -201,10 +205,14 @@ TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) { | |||
| auto output_data = reinterpret_cast<float *>(dx_tensor.MutableData()); | |||
| std::vector<lite::Tensor *> outputs = {&dx_tensor}; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(pooling_param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(pooling_param), &context, desc, nullptr); | |||
| kernel_obj->Run(); | |||
| @@ -259,17 +267,22 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride2Fp32) { | |||
| float *out_data = static_cast<float *>(out_tensor.MutableData()); | |||
| std::vector<lite::Tensor *> inputs = {&yt_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&out_tensor}; | |||
| // ---------------------------------------- | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey pool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| auto pool_creator = lite::KernelRegistry::GetInstance()->GetCreator(pool_desc); | |||
| auto kernel = pool_creator(inputs, outputs, reinterpret_cast<OpParameter *>(pool), NULL, pool_desc, nullptr); | |||
| auto kernel = pool_creator(inputs, outputs, reinterpret_cast<OpParameter *>(pool), &context, pool_desc, nullptr); | |||
| kernel->Init(); | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| kernel->Run(); | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| printf("single thread running time : %llu ms\n", time_end - time_start); | |||
| printf("single thread running time : %lu ms\n", time_end - time_start); | |||
| std::string output_path = "./test_data/pooling/avgpoolgradfp32_s2_dx_3_28_28_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| @@ -319,17 +332,22 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride3Fp32) { | |||
| std::vector<lite::Tensor *> inputs = {&yt_tensor, &x_tensor}; | |||
| std::vector<lite::Tensor *> outputs = {&out_tensor}; | |||
| // ---------------------------------------- | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey pool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| auto pool_creator = lite::KernelRegistry::GetInstance()->GetCreator(pool_desc); | |||
| auto kernel = pool_creator(inputs, outputs, reinterpret_cast<OpParameter *>(pool), NULL, pool_desc, nullptr); | |||
| auto kernel = pool_creator(inputs, outputs, reinterpret_cast<OpParameter *>(pool), &context, pool_desc, nullptr); | |||
| kernel->Init(); | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| kernel->Run(); | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| printf("single thread running time : %llu ms\n", time_end - time_start); | |||
| printf("single thread running time : %lu ms\n", time_end - time_start); | |||
| std::string output_path = "./test_data/pooling/avgpoolgradfp32_s3_dx_3_28_28_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| @@ -371,13 +389,13 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { | |||
| auto output_data = new float[output_data_size]; | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param); | |||
| MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1); | |||
| } | |||
| int loop_count = 100; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param); | |||
| MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| @@ -435,10 +453,15 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradBatchFp32) { | |||
| auto out_data = static_cast<float *>(out_tensor.MutableData()); | |||
| std::vector<lite::Tensor *> maxpool_inputs = {&x_tensor, &y_tensor, &yt_tensor}; | |||
| std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor}; | |||
| // ---------------------------------------- | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey maxpool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| auto maxpool_creator = lite::KernelRegistry::GetInstance()->GetCreator(maxpool_desc); | |||
| auto kernel = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast<OpParameter *>(maxpool), NULL, | |||
| auto kernel = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast<OpParameter *>(maxpool), &context, | |||
| maxpool_desc, nullptr); | |||
| kernel->Init(); | |||
| @@ -446,7 +469,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradBatchFp32) { | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| kernel->Run(); | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| printf("single thread running time : %llu ms\n", time_end - time_start); | |||
| printf("single thread running time : %lu ms\n", time_end - time_start); | |||
| std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_xgrad_3_28_28_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| @@ -505,10 +528,15 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride2Fp32) { | |||
| std::vector<lite::Tensor *> maxpool_inputs = {&x_tensor, &y_tensor, &yt_tensor}; | |||
| std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor}; | |||
| // ---------------------------------------- | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey maxpool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| auto maxpool_creator = lite::KernelRegistry::GetInstance()->GetCreator(maxpool_desc); | |||
| auto kernel = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast<OpParameter *>(maxpool), NULL, | |||
| auto kernel = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast<OpParameter *>(maxpool), &context, | |||
| maxpool_desc, nullptr); | |||
| kernel->Init(); | |||
| @@ -516,7 +544,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride2Fp32) { | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| kernel->Run(); | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| printf("single thread running time : %llu ms\n", time_end - time_start); | |||
| printf("single thread running time : %lu ms\n", time_end - time_start); | |||
| std::string output_path = "./test_data/pooling/maxpoolgradfp32_s2_xgrad_3_28_28_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| @@ -575,10 +603,15 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride3Fp32) { | |||
| std::vector<lite::Tensor *> maxpool_inputs = {&x_tensor, &y_tensor, &yt_tensor}; | |||
| std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor}; | |||
| // ---------------------------------------- | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey maxpool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; | |||
| auto maxpool_creator = lite::KernelRegistry::GetInstance()->GetCreator(maxpool_desc); | |||
| auto kernel = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast<OpParameter *>(maxpool), NULL, | |||
| auto kernel = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast<OpParameter *>(maxpool), &context, | |||
| maxpool_desc, nullptr); | |||
| kernel->Init(); | |||
| @@ -586,7 +619,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride3Fp32) { | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| kernel->Run(); | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| printf("single thread running time : %llu ms\n", time_end - time_start); | |||
| printf("single thread running time : %lu ms\n", time_end - time_start); | |||
| std::string output_path = "./test_data/pooling/maxpoolgradfp32_s3_xgrad_3_28_28_3.bin"; | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| @@ -59,9 +59,15 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { | |||
| grad_tensor.SetData(grad); | |||
| std::vector<lite::Tensor *> outputs = {&loss_tensor, &grad_tensor}; | |||
| lite::InnerContext context; | |||
| context.device_type_ = lite::DT_CPU; | |||
| context.thread_num_ = 1; | |||
| ASSERT_EQ(lite::RET_OK, context.Init()); | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftmaxCrossEntropy}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(sce_param), NULL, desc, nullptr); | |||
| auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(sce_param), &context, desc, nullptr); | |||
| mindspore::kernel::LiteKernel::AllocWorkspace(kernel_obj->GetWorkspaceSize()); | |||
| kernel_obj->Run(); | |||
| printf("==================total loss=================\n"); | |||
| @@ -92,6 +98,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { | |||
| y_tensor.SetData(nullptr); | |||
| loss_tensor.SetData(nullptr); | |||
| grad_tensor.SetData(nullptr); | |||
| mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| delete kernel_obj; | |||
| MS_LOG(INFO) << "SoftmaxCrossEntropyFp32 passed"; | |||
| } | |||
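| // Workspace discipline added in this hunk: AllocWorkspace() before Run() and a | |||
| // matching FreeWorkspace() before the kernel is deleted. In the order used above: | |||
| //   mindspore::kernel::LiteKernel::AllocWorkspace(kernel_obj->GetWorkspaceSize()); | |||
| //   kernel_obj->Run(); | |||
| //   mindspore::kernel::LiteKernel::FreeWorkspace(); | |||
| //   delete kernel_obj; | |||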
| @@ -21,13 +21,12 @@ | |||
| #include "mindspore/lite/include/context.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/kernel_registry.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/common/file_utils.h" | |||
| #include "src/common/file_utils_ext.h" | |||
| #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.h" | |||
| #include "mindspore/lite/nnacl/fp32_grad/softmax_grad.h" | |||
| #include "mindspore/lite/src/kernel_registry.h" | |||
| namespace mindspore { | |||
| class TestSoftmaxGradFp32 : public mindspore::CommonTest { | |||
| @@ -55,348 +54,6 @@ void InitSoftMaxParam(SoftmaxParameter *softmax_param, int axis, int n, int c, i | |||
| softmax_param->input_shape_[3] = w; | |||
| } | |||
| #if 0 // kernel testing | |||
| TEST_F(TestSoftmaxGradFp32, SoftmaxGradKernelAxis0) { | |||
| auto softmax_param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter))); | |||
| // set parameters | |||
| InitSoftMaxParam(softmax_param, 0); | |||
| std::vector<int> shape = {1, 9, 11, 12}; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/softmax/softmaxgrad_yinput.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor input_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| input_tensor.SetData(input_data); | |||
| std::string yt_path = "./test_data/softmax/softmaxgrad_yt_input.bin"; | |||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| yt_tensor.SetData(yt_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| auto out_data = new float[softmax_param->element_size_]; | |||
| lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| out_tensor.SetData(out_data); | |||
| std::vector<lite::tensor::Tensor *> inputs = {&input_tensor, &yt_tensor}; | |||
| std::vector<lite::tensor::Tensor *> outputs = {&out_tensor}; | |||
| // float sum_data[6]; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftMaxGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(softmax_param), NULL, desc, nullptr); | |||
| kernel->Init(); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 3; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/softmax/softmaxgrad_out.bin"; | |||
| // auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] yt_data; | |||
| delete[] out_data; | |||
| input_tensor.SetData(nullptr); | |||
| yt_tensor.SetData(nullptr); | |||
| out_tensor.SetData(nullptr); | |||
| delete kernel; | |||
| // delete softmax_param; | |||
| MS_LOG(INFO) << "SoftmaxGradKernelAxis0 passed"; | |||
| } | |||
| TEST_F(TestSoftmaxGradFp32, SoftmaxGradKernelAxis1) { | |||
| auto softmax_param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter))); | |||
| // set parameters | |||
| InitSoftMaxParam(softmax_param, 1); | |||
| std::vector<int> shape = {1, 9, 11, 12}; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/softmax/softmaxgrad_1_yinput.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor input_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| input_tensor.SetData(input_data); | |||
| std::string yt_path = "./test_data/softmax/softmaxgrad_1_yt_input.bin"; | |||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| yt_tensor.SetData(yt_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| auto out_data = new float[softmax_param->element_size_]; | |||
| lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| out_tensor.SetData(out_data); | |||
| std::vector<lite::tensor::Tensor *> inputs = {&input_tensor, &yt_tensor}; | |||
| std::vector<lite::tensor::Tensor *> outputs = {&out_tensor}; | |||
| // float sum_data[6]; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftMaxGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(softmax_param), NULL, desc, nullptr); | |||
| kernel->Init(); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 3; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/softmax/softmaxgrad_1_out.bin"; | |||
| // auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] yt_data; | |||
| delete[] out_data; | |||
| input_tensor.SetData(nullptr); | |||
| yt_tensor.SetData(nullptr); | |||
| out_tensor.SetData(nullptr); | |||
| delete kernel; | |||
| // delete softmax_param; | |||
| MS_LOG(INFO) << "SoftmaxGradKernelAxis1 passed"; | |||
| } | |||
| TEST_F(TestSoftmaxGradFp32, SoftmaxGradKernelAxis2) { | |||
| auto softmax_param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter))); | |||
| // set parameters | |||
| InitSoftMaxParam(softmax_param, 2); | |||
| std::vector<int> shape = {1, 9, 11, 12}; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/softmax/softmaxgrad_2_yinput.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor input_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| input_tensor.SetData(input_data); | |||
| std::string yt_path = "./test_data/softmax/softmaxgrad_2_yt_input.bin"; | |||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| yt_tensor.SetData(yt_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| auto out_data = new float[softmax_param->element_size_]; | |||
| lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| out_tensor.SetData(out_data); | |||
| std::vector<lite::tensor::Tensor *> inputs = {&input_tensor, &yt_tensor}; | |||
| std::vector<lite::tensor::Tensor *> outputs = {&out_tensor}; | |||
| // float sum_data[6]; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftMaxGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(softmax_param), NULL, desc, nullptr); | |||
| kernel->Init(); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 3; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/softmax/softmaxgrad_2_out.bin"; | |||
| // auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] yt_data; | |||
| delete[] out_data; | |||
| input_tensor.SetData(nullptr); | |||
| yt_tensor.SetData(nullptr); | |||
| out_tensor.SetData(nullptr); | |||
| delete kernel; | |||
| // delete softmax_param; | |||
| MS_LOG(INFO) << "SoftmaxGradKernelAxis2 passed"; | |||
| } | |||
| TEST_F(TestSoftmaxGradFp32, SoftmaxGradKernelAxis3) { | |||
| auto softmax_param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter))); | |||
| // set parameters | |||
| InitSoftMaxParam(softmax_param, 3); | |||
| std::vector<int> shape = {1, 9, 11, 12}; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/softmax/softmaxgrad_3_yinput.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor input_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| input_tensor.SetData(input_data); | |||
| std::string yt_path = "./test_data/softmax/softmaxgrad_3_yt_input.bin"; | |||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| yt_tensor.SetData(yt_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| auto out_data = new float[softmax_param->element_size_]; | |||
| lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| out_tensor.SetData(out_data); | |||
| std::vector<lite::tensor::Tensor *> inputs = {&input_tensor, &yt_tensor}; | |||
| std::vector<lite::tensor::Tensor *> outputs = {&out_tensor}; | |||
| // float sum_data[6]; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftMaxGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(softmax_param), NULL, desc, nullptr); | |||
| kernel->Init(); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 3; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/softmax/softmaxgrad_3_out.bin"; | |||
| // auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] yt_data; | |||
| delete[] out_data; | |||
| input_tensor.SetData(nullptr); | |||
| yt_tensor.SetData(nullptr); | |||
| out_tensor.SetData(nullptr); | |||
| delete kernel; | |||
| // delete softmax_param; | |||
| MS_LOG(INFO) << "SoftmaxGradKernelAxis3 passed"; | |||
| } | |||
| TEST_F(TestSoftmaxGradFp32, SoftmaxGradKernelAxisMinus1) { | |||
| auto softmax_param = reinterpret_cast<SoftmaxParameter *>(malloc(sizeof(SoftmaxParameter))); | |||
| // set parameters | |||
| InitSoftMaxParam(softmax_param, -1); | |||
| std::vector<int> shape = {1, 9, 11, 12}; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/softmax/softmaxgrad_-1_yinput.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor input_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| input_tensor.SetData(input_data); | |||
| std::string yt_path = "./test_data/softmax/softmaxgrad_-1_yt_input.bin"; | |||
| auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); | |||
| lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| yt_tensor.SetData(yt_data); | |||
| // runtime part | |||
| printf("Calculating runtime cost...\n"); | |||
| uint64_t time_avg = 0; | |||
| auto out_data = new float[softmax_param->element_size_]; | |||
| lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, shape); | |||
| out_tensor.SetData(out_data); | |||
| std::vector<lite::tensor::Tensor *> inputs = {&input_tensor, &yt_tensor}; | |||
| std::vector<lite::tensor::Tensor *> outputs = {&out_tensor}; | |||
| // float sum_data[6]; | |||
| kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftMaxGrad}; | |||
| auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); | |||
| auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(softmax_param), NULL, desc, nullptr); | |||
| kernel->Init(); | |||
| // warm up loop | |||
| for (int i = 0; i < 3; i++) { | |||
| kernel->Run(); | |||
| } | |||
| int loop_count = 3; | |||
| auto time_start = mindspore::lite::GetTimeUs(); | |||
| for (int i = 0; i < loop_count; i++) { | |||
| kernel->Run(); | |||
| } | |||
| auto time_end = mindspore::lite::GetTimeUs(); | |||
| auto cost = time_end - time_start; | |||
| time_avg = cost / loop_count; | |||
| printf("single thread running time : %f ms\n", time_avg / 1000.0f); | |||
| std::string output_path = "./test_data/softmax/softmaxgrad_-1_out.bin"; | |||
| // auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| auto res = lite::CompareRelativeOutput(out_data, output_path); | |||
| EXPECT_EQ(res, 0); | |||
| delete[] input_data; | |||
| delete[] yt_data; | |||
| delete[] out_data; | |||
| input_tensor.SetData(nullptr); | |||
| yt_tensor.SetData(nullptr); | |||
| out_tensor.SetData(nullptr); | |||
| delete kernel; | |||
| // delete softmax_param; | |||
| MS_LOG(INFO) << "SoftmaxGradKernelAxisMinus1 passed"; | |||
| } | |||
| #endif // kernel testing | |||
| TEST_F(TestSoftmaxGradFp32, SoftmaxGradAxis0) { | |||
| auto softmax_param = new SoftmaxParameter(); | |||
| // set parameters | |||
| @@ -0,0 +1,2 @@ | |||
| (binary float32 test-data file added by this hunk; raw contents not representable as text) | |||