diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h index 2a880ca0ae..eb96f4dbe6 100644 --- a/mindspore/lite/include/model.h +++ b/mindspore/lite/include/model.h @@ -69,6 +69,7 @@ class MS_API Model { /// \brief Free MetaGraph in MindSpore Lite Model. void FreeMetaGraph(); + ModelImpl *model_impl() {return model_impl_;} protected: ModelImpl *model_impl_ = nullptr; diff --git a/mindspore/lite/include/train_session.h b/mindspore/lite/include/train_session.h new file mode 100644 index 0000000000..724e112f0e --- /dev/null +++ b/mindspore/lite/include/train_session.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_ +#define MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_ +#include +#include +#include +// #include "include/lite_session.h" +#include "src/lite_session.h" + +namespace mindspore { +namespace lite { +class Model; +} +namespace lite::tensor { +class Tensor; +} +namespace session { + +class TrainSession : public lite::LiteSession { + public: + TrainSession(); + ~TrainSession() = default; + + int RunGraph(const session::KernelCallBack &before = nullptr, + const session::KernelCallBack &after = nullptr) override; + + int CompileGraph(lite::Model *model) override; + virtual void ReplaceOps(); + virtual void* ExportToBuf(void* buf, size_t* len) const; + + std::unordered_map> GetOutputs() const; + std::vector GetOutputsByName(const std::string &node_name) const; + + virtual void train(); + bool is_train() { return train_mode_ == true; } + virtual void eval(); + bool is_eval() { return train_mode_ == false; } + + protected: + bool train_mode_ = false; + lite::Model* model_ = nullptr; + std::unordered_map> ext_output_map_; + + + // private: +}; +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_ diff --git a/mindspore/lite/nnacl/activation_grad.c b/mindspore/lite/nnacl/fp32_grad/activation_grad.c similarity index 90% rename from mindspore/lite/nnacl/activation_grad.c rename to mindspore/lite/nnacl/fp32_grad/activation_grad.c index 7dcc11b558..d7b070473d 100644 --- a/mindspore/lite/nnacl/activation_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/activation_grad.c @@ -13,9 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "nnacl/activation_grad.h" -int ReluGrad(float *src0, float *src1, int length, float *dst) { +#include +#include "nnacl/op_base.h" +#include "nnacl/fp32/arithmetic.h" +#include "nnacl/fp32_grad/activation_grad.h" +#include "nnacl/errorcode.h" + +inline int ReluGrad(float *src0, float *src1, int length, float *dst) { for (int i = 0; i < length; ++i) { dst[i] = src1[i] > 0 ? 1.0f : 0.0f; } diff --git a/mindspore/lite/nnacl/activation_grad.h b/mindspore/lite/nnacl/fp32_grad/activation_grad.h similarity index 100% rename from mindspore/lite/nnacl/activation_grad.h rename to mindspore/lite/nnacl/fp32_grad/activation_grad.h diff --git a/mindspore/lite/nnacl/fp32_grad/batch_norm.c b/mindspore/lite/nnacl/fp32_grad/batch_norm.c index fc9893de2d..76cb832b66 100644 --- a/mindspore/lite/nnacl/fp32_grad/batch_norm.c +++ b/mindspore/lite/nnacl/fp32_grad/batch_norm.c @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include +#include #include "nnacl/fp32_grad/batch_norm.h" -static void sumSpatialBatch(const float *in, int size, int ch, float *out) { +void sumSpatialBatch(const float *in, int size, int ch, float *out) { memset(out, 0, ch * sizeof(float)); for (int i = 0; i < size; i++) { const float *ptr = in + i * ch; @@ -32,49 +32,53 @@ void scaleBias(const float *scales, int batch, int n, int size, float *output) { for (int c = 0; c < n; c++) output[i * n + c] *= scales[c]; } -void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial, +void normalize(const float *x, const float *mean, const float *invar, int batch, int filters, int spatial, float *out) { int b, f, i; for (b = 0; b < batch; ++b) { for (i = 0; i < spatial; ++i) { for (f = 0; f < filters; ++f) { int index = b * filters * spatial + i * filters + f; - out[index] = (x[index] - mean[f]) / (sqrt(variance[f]) + eps); + out[index] = (x[index] - mean[f]) * invar[f]; } } } } -void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates) { +void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, + int n, int size, float *scale_updates) { int i, b, f; memset(scale_updates, 0, n * sizeof(float)); for (b = 0; b < batch; ++b) { for (i = 0; i < size; ++i) { for (f = 0; f < n; ++f) { int index = (b * size + i) * n + f; - scale_updates[f] += delta[index] * x_norm[index]; + float x_norm = (x[index] - mean[f]) * invar[f]; + scale_updates[f] += delta[index] * x_norm; } } } } -void meanVar(const float *in, int batch, int spatial, int ch, float *mean, float *var) { +void meanVar(const float *in, int batch, int spatial, int ch, float eps, float *mean, float *invar) { float N = batch * spatial; sumSpatialBatch(in, N, ch, mean); - for (int f = 0; f < ch; ++f) mean[f] /= N; - memset(var, 0, ch * sizeof(float)); - for (int i = 0; i < N; i++) { - for (int f = 0; f < ch; f++) { - float x = in[i * ch + f]; - var[f] += (x - mean[f]) * (x - mean[f]); + for (int f = 0; f < ch; ++f) { + mean[f] /= N; + } + for (int f=0; f< ch; f++) { + float tvar = 0; + for (int i =0; i< N; i++) { + float x = in[i*ch +f]; + tvar += (x-mean[f]) *(x-mean[f]); } + invar[f] = 1.0f/(sqrt(tvar/N+eps)); } - for (int f = 0; f < ch; f++) var[f] /= N; } -void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta) { +void meanDelta(float *yt, int size, int ch, float *invar, float *mean_delta) { sumSpatialBatch(yt, size, ch, mean_delta); - for (int i = 0; i < ch; i++) mean_delta[i] *= -1.f / sqrt((variance[i] + eps)); + for (int i = 0; i < ch; i++) mean_delta[i] *= -invar[i]; } void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial, @@ -93,8 +97,8 @@ void meanAdd(const float *x, const float *mean, const float *variance_delta, int } } -void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int filters, - int spatial, float eps, float *variance_delta) { +void varianceDelta(const float *x, const float *delta, const float *mean, const float *invar, int batch, int filters, + int spatial, float *variance_delta) { int i, k; memset(variance_delta, 0, filters * sizeof(float)); for (k = 0; k < batch * spatial; k++) { @@ -103,16 +107,16 @@ void varianceDelta(const float *x, const float *delta, const float *mean, const variance_delta[i] += delta[index] * (x[index] - mean[i]); } } - for (i = 0; i < filters; i++) variance_delta[i] *= -.5 * pow(variance[i] + eps, (-3.f / 2.f)); + for (i = 0; i < filters; i++) variance_delta[i] *= -.5 * 1.0f/(invar[i]*invar[i]*invar[i]); } -void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, - const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta) { +void NormalizeDelta(const float *x, const float *mean, const float *invar, const float *mean_delta, + const float *variance_delta, int batch, int filters, int spatial, float *delta) { int f, k; for (k = 0; k < batch * spatial; k++) { for (f = 0; f < filters; f++) { int index = k * filters + f; - delta[index] = delta[index] * 1. / (sqrt(variance[f] + eps)) + + delta[index] = delta[index] * invar[f] + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f] / (spatial * batch); } diff --git a/mindspore/lite/nnacl/fp32_grad/batch_norm.h b/mindspore/lite/nnacl/fp32_grad/batch_norm.h index 488ef98980..1603aa0c07 100644 --- a/mindspore/lite/nnacl/fp32_grad/batch_norm.h +++ b/mindspore/lite/nnacl/fp32_grad/batch_norm.h @@ -17,28 +17,33 @@ #ifndef MINDSPORE_LITE_NNACL_FP32_BATCH_NORM_H_ #define MINDSPORE_LITE_NNACL_FP32_BATCH_NORM_H_ -typedef struct bnParameter { - int batch; - int channels; - int spatial; - float eps; -} bnParameter; +#include "nnacl/op_base.h" + +typedef struct BNGradParameter { + OpParameter op_parameter_; + float epsilon_; + float momentum_; +} BNGradParameter; #ifdef __cplusplus extern "C" { #endif + + +void sumSpatialBatch(const float *in, int size, int ch, float *out); void scaleBias(const float *scales, int batch, int n, int size, float *output); -void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial, +void normalize(const float *x, const float *mean, const float *invar, int batch, int filters, int spatial, float *out); -void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates); -void meanVar(const float *in, int batch, int size, int ch, float *mean, float *var); -void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta); -void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int ch, - int spatial, float eps, float *variance_delta); +void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, + int n, int size, float *scale_updates); +void meanVar(const float *in, int batch, int size, int ch, float eps, float *mean, float *invar); +void meanDelta(float *yt, int size, int ch, float *invar, float *mean_delta); +void varianceDelta(const float *x, const float *delta, const float *mean, const float *invar, int batch, int ch, + int spatial, float *variance_delta); void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial, float *mean_add, float *mean_delta); -void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta, - const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta); +void NormalizeDelta(const float *x, const float *mean, const float *invar, const float *mean_delta, + const float *variance_delta, int batch, int filters, int spatial, float *delta); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32_grad/pack_ext.c b/mindspore/lite/nnacl/fp32_grad/pack_ext.c index 48665e83f2..fd11c3da8b 100644 --- a/mindspore/lite/nnacl/fp32_grad/pack_ext.c +++ b/mindspore/lite/nnacl/fp32_grad/pack_ext.c @@ -125,9 +125,9 @@ void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param } void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param) { - const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_; + const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_; // const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_; - const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_; + const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_; // const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_; const int stride_h = conv_param->stride_h_; diff --git a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c index a2a2001288..58423c192c 100644 --- a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c @@ -13,7 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include +#include #include "nnacl/fp32_grad/pooling_grad.h" void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) { @@ -31,33 +32,37 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter int output_batch = pooling_param->output_batch_; const float *inPtr = NULL; - for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0; + // for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0; + for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0; float kk = (float)(win_h * win_w); for (uint16_t ib = 0; ib < output_batch; ib++) { float *out; - out = &output_ptr[(ib * output_h * output_w)]; - inPtr = (float *)(&input_ptr[(ib * in_h * in_w)]); + // out = &output_ptr[(ib * output_h * output_w)]; + out = &output_ptr[(ib * in_h * in_w * channel)]; + // inPtr = (float *)(&input_ptr[(ib * in_h * in_w)]); + inPtr = (float *)(&input_ptr[(ib * output_h * output_w * channel)]); if (1) { // in->layout() == Tensor::nhwc) // iterate over yt - for (uint16_t yh = 0; yh < in_h; yh++) { - for (uint16_t yw = 0; yw < in_w; yw++) { + for (uint16_t yh = 0; yh < output_h; yh++) { + for (uint16_t yw = 0; yw < output_w; yw++) { for (uint16_t ic = 0; ic < channel; ic++) { - int idx = (yw + yh * in_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw; + int idx = (yw + yh * output_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw; float delta = inPtr[idx] / kk; for (int32_t kh = 0; kh < win_h; kh++) { int xh = yh * stride_h + kh - pad_h; - if ((xh < 0) || (xh >= output_h)) { + if ((xh < 0) || (xh >= in_h)) { continue; } for (int32_t kw = 0; kw < win_w; kw++) { int xw = yw * stride_w + kw - pad_w; - if ((xw < 0) || (xw >= output_w)) { + if ((xw < 0) || (xw >= in_w)) { continue; } - // out[(ic*output_h*output_w) + (xh*output_w) + xw] += delta; - out[(xw + output_w * xh) * channel + ic] += delta; + + // out[(xw + output_w * xh) * channel + ic] += delta; + out[(xw + in_w * xh) * channel + ic] += delta; } } } @@ -66,21 +71,22 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter } else { // nchw for (uint16_t ic = 0; ic < channel; ic++) { // iterate over yt - for (uint16_t yh = 0; yh < in_h; yh++) { - for (uint16_t yw = 0; yw < in_w; yw++) { - int idx = (ic * in_h * in_w) + (in_w * yh) + yw; + for (uint16_t yh = 0; yh < output_h; yh++) { + for (uint16_t yw = 0; yw < output_w; yw++) { + int idx = (ic * output_h * output_w) + (output_w * yh) + yw; float delta = inPtr[idx] / kk; for (int32_t kh = 0; kh < win_h; kh++) { int xh = yh * stride_h + kh - pad_h; - if ((xh < 0) || (xh >= output_h)) { + if ((xh < 0) || (xh >= in_h)) { continue; } for (int32_t kw = 0; kw < win_w; kw++) { int xw = yw * stride_w + kw - pad_w; - if ((xw < 0) || (xw >= output_w)) { + if ((xw < 0) || (xw >= in_w)) { continue; } - out[(ic * output_h * output_w) + (xh * output_w) + xw] += delta; + // out[(ic * output_h * output_w) + (xh * output_w) + xw] += delta; + out[(ic * in_h * in_w) + (xh * in_w) + xw] += delta; } } } @@ -90,7 +96,14 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter } } -void MaxPoolingGrad(const float *dy, const int *indices, float *output_ptr, PoolingParameter *pooling_param) { +void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr, + PoolingParameter *pooling_param) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; int channel = pooling_param->input_channel_; int in_w = pooling_param->input_w_; int in_h = pooling_param->input_h_; @@ -98,38 +111,73 @@ void MaxPoolingGrad(const float *dy, const int *indices, float *output_ptr, Pool int output_h = pooling_param->output_h_; int output_batch = pooling_param->output_batch_; - int out_img_size = - output_h * output_w; // Emir -- in original code this varible is calculated according to input size ?? - int ind_img_size = in_h * in_w; - // const int w_pad = (output_w + pad_w + pad_w); + const float *inPtr; + const float *dyPtr; + + for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0; + + for (uint16_t ib = 0; ib < output_batch; ib++) { + float *out; + out = &output_ptr[(ib * in_h * in_w * channel)]; + inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]); + dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]); - for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0; + if (1) { // nhwc + for (uint16_t yh = 0; yh < output_h; yh++) { + for (uint16_t yw = 0; yw < output_w; yw++) { + for (uint16_t ic = 0; ic < channel; ic++) { + int idx = (yw + yh * output_w) * channel + ic; - const float *yt = (const float *)(dy); - const int *pos = (const int *)(indices); - float *out = NULL; + float delta = dyPtr[idx]; + float max_val = -FLT_MAX; + int max_idx = 0; + for (int32_t kh = 0; kh < win_h; kh++) { + int xh = yh * stride_h + kh - pad_h; + if ((xh < 0) || (xh >= in_h)) { + continue; + } + for (int32_t kw = 0; kw < win_w; kw++) { + int xw = yw * stride_w + kw - pad_w; + if ((xw < 0) || (xw >= in_w)) { + continue; + } - if (1) { // grads->layout() == Tensor::nhwc) - for (int ib = 0; ib < output_batch; ib++) { - out = &(output_ptr[ib * output_w * output_w * channel]); - for (int ix = 0; ix < ind_img_size; ix++) { - for (int cix = 0; cix < channel; cix++) { - int idx = (*pos) * channel + cix; - out[idx] += *yt; - pos++; - yt++; + if (inPtr[(xw + in_w * xh) * channel + ic] > max_val) { + max_val = inPtr[(xw + in_w * xh) * channel + ic]; + max_idx = (xw + in_w * xh) * channel + ic; + } + } + } + out[max_idx] += delta; + } } } - } - } else { - for (int ib = 0; ib < output_batch; ib++) { - out = &output_ptr[(ib * out_img_size)]; - for (int cix = 0; cix < channel; cix++) { - for (int ix = 0; ix < ind_img_size; ix++) { - int idx = cix * output_h * output_w + *pos; // cord_y*output_w + cord_x; - out[idx] += *yt; - pos++; - yt++; + } else { // nchw + for (uint16_t yh = 0; yh < output_h; yh++) { + for (uint16_t yw = 0; yw < output_w; yw++) { + for (uint16_t ic = 0; ic < channel; ic++) { + int idx = (ic * output_h * output_w) + (output_w * yh) + yw; + float delta = dyPtr[idx]; + float max_val = -FLT_MAX; + int max_idx = 0; + for (int32_t kh = 0; kh < win_h; kh++) { + int xh = yh * stride_h + kh - pad_h; + if ((xh < 0) || (xh >= in_h)) { + continue; + } + for (int32_t kw = 0; kw < win_w; kw++) { + int xw = yw * stride_w + kw - pad_w; + if ((xw < 0) || (xw >= in_w)) { + continue; + } + if (inPtr[(ic * in_h * in_w) + (xh * in_w) + xw] > max_val) { + max_val = inPtr[(ic * in_h * in_w) + (xh * in_w) + xw]; + max_idx = (ic * in_h * in_w) + (xh * in_w) + xw; + } + } + } + out[max_idx] += delta; + } } } } diff --git a/mindspore/lite/nnacl/fp32_grad/pooling_grad.h b/mindspore/lite/nnacl/fp32_grad/pooling_grad.h index e8ba74b48d..4d27e21d7b 100644 --- a/mindspore/lite/nnacl/fp32_grad/pooling_grad.h +++ b/mindspore/lite/nnacl/fp32_grad/pooling_grad.h @@ -23,7 +23,9 @@ extern "C" { #endif void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param); -void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param); +// void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param); +void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr, + PoolingParameter *pooling_param); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32_grad/reduce_grad.c b/mindspore/lite/nnacl/fp32_grad/reduce_grad.c index 332677013b..6963969817 100644 --- a/mindspore/lite/nnacl/fp32_grad/reduce_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/reduce_grad.c @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#include #include "nnacl/fp32_grad/reduce_grad.h" -static inline bool NextIndex(const int num_dims, const int *dims, int *current) { +static inline int NextIndex(const int num_dims, const int *dims, int *current) { int carry = 1; for (int idx = num_dims - 1; idx >= 0; --idx) { int current_val = current[idx] + carry; @@ -45,10 +45,10 @@ static inline size_t GetOutputOffset(const int num_dims, const int *dims, const size_t offset = 0; for (int idx = 0; idx < num_dims; ++idx) { // if we need to skip this axis - bool is_axis = false; + int is_axis = 0; for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { if (idx == axes[axis_idx]) { - is_axis = true; + is_axis = 1; break; } } @@ -101,10 +101,10 @@ float ReduceMeanAll(const float *src, int size) { void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims) { int num_outputs = 1; - int same_shape = true; + int same_shape = 1; for (int idx = 0; idx < num_dims; ++idx) { num_outputs *= output_dims[idx]; - if (output_dims[idx] != input_dims[idx]) same_shape = false; + if (output_dims[idx] != input_dims[idx]) same_shape = 0; } if (same_shape) { memcpy(output, input, num_outputs * sizeof(float)); diff --git a/mindspore/lite/nnacl/fp32_grad/reduce_grad.h b/mindspore/lite/nnacl/fp32_grad/reduce_grad.h index 814fdace08..52c633c83d 100644 --- a/mindspore/lite/nnacl/fp32_grad/reduce_grad.h +++ b/mindspore/lite/nnacl/fp32_grad/reduce_grad.h @@ -17,8 +17,7 @@ #ifndef MINDSPORE_LITE_NNACL_FP32_REDUCE_GRAD_H_ #define MINDSPORE_LITE_NNACL_FP32_REDUCE_GRAD_H_ -#include -#include +#include #ifdef __cplusplus extern "C" { diff --git a/mindspore/lite/nnacl/fp32_grad/softmax_grad.h b/mindspore/lite/nnacl/fp32_grad/softmax_grad.h index 726b6ae271..907810982d 100644 --- a/mindspore/lite/nnacl/fp32_grad/softmax_grad.h +++ b/mindspore/lite/nnacl/fp32_grad/softmax_grad.h @@ -20,7 +20,7 @@ #include "nnacl/op_base.h" typedef struct SoftmaxCrossEntropyParameter { - OpParameter op_parameter; + OpParameter op_parameter_; int32_t batch_size_; unsigned int number_of_classes_; int n_dim_; diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index d9c7a90444..d2d7b806e8 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -178,8 +178,8 @@ union PrimitiveType { Conv2DGradFilter, Conv2DGradInput, PoolingGrad, - BNGradInput, - OptMomentum, + BNGrad, + ApplyMomentum, BiasGrad, SoftmaxCrossEntropy, AddGrad, @@ -190,6 +190,7 @@ union PrimitiveType { ActivationGrad, PriorBox, SpaceToBatchND, + Depend, Return, MakeTuple, ToFormat, diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 1519d0a80f..eccf905448 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -149,7 +149,8 @@ table Activation { alpha: float = 0.2; } table ActivationGrad { - type: ActivationGradType = 0; + type: ActivationType = 0; + alpha: float = 0.2; } @@ -230,6 +231,9 @@ table SoftmaxCrossEntropy { axis: [int]; } +table make_tuple { +} + table PoolingGrad { format: Format = 0; @@ -390,10 +394,11 @@ table DeConv2D { hasBias: bool = false; activationType: ActivationType = 0; } -table BNGradInput { +table BNGrad { eps : float; - channels: int; + momentum: float; } + table Scale { axis: int; } @@ -841,7 +846,10 @@ table SquaredDifference { table TupleGetItem { } -table OptMomentum { +table ApplyMomentum { + gradientScale: float; + useLocking: bool; + useNesterov: bool; } @@ -884,6 +892,10 @@ table ToFormat { dstT: int; } + +table Depend { +} + table Return { } diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index eda8d8c089..036f0d756d 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -27,7 +27,7 @@ set(LITE_SRC ) if (SUPPORT_GPU) - set(LITE_SRC +set(LITE_SRC ${LITE_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/subgraph_opencl_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc @@ -36,6 +36,24 @@ if (SUPPORT_GPU) ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc ) +endif() + + +if (SUPPORT_TRAIN) + set(ANF_SRC + ${ANF_SRC} + + ) + set(PASS_SRC) + set(LITE_SRC + ${LITE_SRC} + ${ANF_SRC} + # ${CMAKE_CURRENT_SOURCE_DIR}/train/ops/train_ops.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc + ) + endif () file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc) diff --git a/mindspore/lite/src/common/file_utils.cc b/mindspore/lite/src/common/file_utils.cc index c703c2288e..6cbb7b8dfc 100644 --- a/mindspore/lite/src/common/file_utils.cc +++ b/mindspore/lite/src/common/file_utils.cc @@ -110,6 +110,7 @@ int CompareOutputData(float *output_data, float *correct_data, int data_size) { } } error /= data_size; + if (error > 0.0001) { printf("has accuracy error!\n"); printf("%f\n", error); @@ -118,12 +119,14 @@ int CompareOutputData(float *output_data, float *correct_data, int data_size) { return 0; } -void CompareOutput(float *output_data, std::string file_path) { +int CompareOutput(float *output_data, std::string file_path) { size_t output_size; auto ground_truth = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); size_t output_num = output_size / sizeof(float); printf("output num : %zu\n", output_num); - CompareOutputData(output_data, ground_truth, output_num); + int res = CompareOutputData(output_data, ground_truth, output_num); + delete [] ground_truth; + return res; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/common/file_utils.h b/mindspore/lite/src/common/file_utils.h index ff1ec03e64..03fdf5360d 100644 --- a/mindspore/lite/src/common/file_utils.h +++ b/mindspore/lite/src/common/file_utils.h @@ -47,7 +47,7 @@ void WriteToTxt(const std::string& file_path, void *data, size_t element_size) { int WriteToBin(const std::string& file_path, void *data, size_t size); int CompareOutputData(float *output_data, float *correct_data, int data_size); -void CompareOutput(float *output_data, std::string file_path); +int CompareOutput(float *output_data, std::string file_path); std::string GetAndroidPackageName(); std::string GetAndroidPackagePath(); diff --git a/mindspore/lite/src/common/file_utils_ext.cc b/mindspore/lite/src/common/file_utils_ext.cc index cdaa337e23..ade264d7b7 100644 --- a/mindspore/lite/src/common/file_utils_ext.cc +++ b/mindspore/lite/src/common/file_utils_ext.cc @@ -47,7 +47,9 @@ int CompareRelativeOutput(float *output_data, std::string file_path) { auto ground_truth = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); size_t output_num = output_size / sizeof(float); std::cout << "output num : " << output_num << "\n"; - return CompareOutputRelativeData(output_data, ground_truth, output_num); + int res = CompareOutputRelativeData(output_data, ground_truth, output_num); + delete [] ground_truth; + return res; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc index 355845f91c..3877f8980f 100644 --- a/mindspore/lite/src/executor.cc +++ b/mindspore/lite/src/executor.cc @@ -39,6 +39,10 @@ int Executor::Run(std::vector &in_tensors, std::vectorSetRefCount(out_tensor->RefCount() + 1); + } + for (auto *kernel : kernels) { MS_ASSERT(nullptr != kernel); @@ -48,6 +52,8 @@ int Executor::Run(std::vector &in_tensors, std::vectorname(); } } + // JBDEBUG + // std::cout << "executing kernel " << kernel->name() << "\n"; auto ret = kernel->Run(); if (0 != ret) { MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 713257ac2d..6a166024cb 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -27,7 +27,6 @@ #include "src/ir/tensor.h" #include "include/errorcode.h" - // using mindspore::kernel::AddressPtr; namespace mindspore::kernel { using mindspore::lite::RET_ERROR; diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index 49b9b49203..66ad91a3e4 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -112,11 +112,11 @@ int ModelImpl::BuildOps() { Model *Model::Import(const char *model_buf, size_t size) { auto model = new Model(); + model->model_impl_ = ModelImpl::Import(model_buf, size); if (model_buf == nullptr) { MS_LOG(ERROR) << "model buf is null"; return nullptr; } - model->model_impl_ = ModelImpl::Import(model_buf, size); if (model->model_impl_ == nullptr) { MS_LOG(ERROR) << "model impl is null"; return nullptr; diff --git a/mindspore/lite/src/ops/activation_grad.cc b/mindspore/lite/src/ops/activation_grad.cc index a82479a6a6..0554e53040 100644 --- a/mindspore/lite/src/ops/activation_grad.cc +++ b/mindspore/lite/src/ops/activation_grad.cc @@ -20,11 +20,11 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE int ActivationGrad::GetType() const { return this->primitive_->value.AsActivationGrad()->type; } - +float ActivationGrad::GetAlpha() const { return this->primitive_->value.AsActivationGrad()->alpha; } void ActivationGrad::SetType(int type) { - this->primitive_->value.AsActivationGrad()->type = (schema::ActivationGradType)type; + this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type; } - +void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; } #else int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); @@ -40,7 +40,7 @@ int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flat return RET_OK; } int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); } - +float ActivationGrad::GetAlpha() const { return this->primitive_->value_as_ActivationGrad()->alpha(); } #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/activation_grad.h b/mindspore/lite/src/ops/activation_grad.h index f4461d30c2..463043bd7b 100644 --- a/mindspore/lite/src/ops/activation_grad.h +++ b/mindspore/lite/src/ops/activation_grad.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_ACTIVATION_GRAD_H_ +#define MINDSPORE_LITE_SRC_OPS_ACTIVATION_GRAD_H_ #include #include @@ -32,13 +32,15 @@ class ActivationGrad : public PrimitiveC { ActivationGrad() = default; explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetType(int type); + void SetAlpha(float alpha); #else ActivationGrad() = default; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetType() const; + float GetAlpha() const; }; } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_OPS_ACTIVATION_GRAD_H_ diff --git a/mindspore/lite/src/ops/apply_momentum.cc b/mindspore/lite/src/ops/apply_momentum.cc new file mode 100644 index 0000000000..b50716b1eb --- /dev/null +++ b/mindspore/lite/src/ops/apply_momentum.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/ops/apply_momentum.h" +namespace mindspore { +namespace lite { + + +#ifdef PRIMITIVE_WRITEABLE + +#else +int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_ApplyMomentum(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateApplyMomentum(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ActivationGrad, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif + +int ApplyMomentum::InferShape(std::vector inputs, std::vector outputs) { + if (5 != inputs.size()) { + MS_LOG(ERROR) << "ApplyMomentum should have at 5 input tensors"; + return RET_ERROR; + } + // if (outputs.empty()) { + // MS_LOG(ERROR) << "ApplyMomentumCPUKernel error input output size!"; + // return RET_ERROR; + // } + + if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[3]->ElementsNum() || + inputs[2]->ElementsNum() != 1 || inputs[4]->ElementsNum() != 1) { + MS_LOG(ERROR) << "error input data size!"; + return RET_ERROR; + } + if (!outputs.empty()) { + auto *out = outputs.front(); + MS_ASSERT(out != nullptr); + out->set_data_type(inputs[0]->data_type()); + out->SetFormat(inputs[0]->GetFormat()); + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/apply_momentum.h b/mindspore/lite/src/ops/apply_momentum.h new file mode 100644 index 0000000000..77ecf588d9 --- /dev/null +++ b/mindspore/lite/src/ops/apply_momentum.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPS_APPLY_MOMENTUM_H_ +#define MINDSPORE_LITE_SRC_OPS_APPLY_MOMENTUM_H_ + +#include +#include +#include +#include "ir/dtype/type_id.h" +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class ApplyMomentum : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC); + ApplyMomentum() = default; + explicit ApplyMomentum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#else + ApplyMomentum() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_OPS_APPLY_MOMENTUM_H_ diff --git a/mindspore/lite/src/ops/arithmetic_grad.cc b/mindspore/lite/src/ops/arithmetic_grad.cc new file mode 100644 index 0000000000..ee57bb6443 --- /dev/null +++ b/mindspore/lite/src/ops/arithmetic_grad.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/arithmetic_grad.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore { +namespace lite { +int ArithmeticGrad::InferShape(std::vector inputs_, + std::vector outputs_) { + if (inputs_.size() != 3) { + MS_LOG(ERROR) << "The number of input must be 3"; + return RET_ERROR; + } + if (outputs_.size() != 2) { + MS_LOG(ERROR) << "The number of output must be 2"; + return RET_ERROR; + } + auto dy = inputs_[0]; + auto x1 = inputs_[1]; + auto x2 = inputs_[2]; + auto dx1 = outputs_[0]; + auto dx2 = outputs_[1]; + + MS_ASSERT(dy != nullptr); + MS_ASSERT(x1 != nullptr); + MS_ASSERT(x2 != nullptr); + MS_ASSERT(dx1 != nullptr); + MS_ASSERT(dx2 != nullptr); + + auto inShape0 = x1->shape(); + auto inShape1 = x2->shape(); + auto outShape = dy->shape(); + + if ((Type() == schema::PrimitiveType_AddGrad) || (Type() == schema::PrimitiveType_SubGrad)) { + ndim_ = outShape.size(); + auto fillDimNum0 = outShape.size() - inShape0.size(); + auto fillDimNum1 = outShape.size() - inShape1.size(); + int j0 = 0; + int j1 = 0; + for (unsigned int i = 0; i < outShape.size(); i++) { + x1_shape_[i] = (i < fillDimNum0) ? 1 : inShape0[j0++]; + x2_shape_[i] = (i < fillDimNum1) ? 1 : inShape1[j1++]; + dy_shape_[i] = outShape[i]; + } + } else { + // if (inShape0.size() < inShape1.size()) + if (dx1->ElementsNum() < dx2->ElementsNum()) { + ndim_ = inShape1.size(); + auto fillDimNum = inShape1.size() - inShape0.size(); // This will not work for batch! + int j = 0; + for (unsigned int i = 0; i < inShape1.size(); i++) { + if (i < fillDimNum) { + x2_shape_[i] = 1; + } else { + x2_shape_[i] = inShape0[j++]; + } + x1_shape_[i] = inShape1[i]; + dy_shape_[i] = outShape[i]; + } + } else if (dx2->ElementsNum() < dx1->ElementsNum()) { // if (inShape0.size() > inShape1.size()) + ndim_ = inShape0.size(); + broadcasting_ = true; + ndim_ = inShape0.size(); + int j = 0; + auto fillDimNum = inShape0.size() - inShape1.size(); + for (unsigned int i = 0; i < inShape0.size(); i++) { + if (i < fillDimNum) { + x2_shape_[i] = 1; + } else { + x2_shape_[i] = inShape1[j++]; + } + x1_shape_[i] = inShape0[i]; + dy_shape_[i] = outShape[i]; + } + } else { + broadcasting_ = false; + for (unsigned int i = 0; i < inShape0.size(); i++) { + x2_shape_[i] = inShape1[i]; + x1_shape_[i] = inShape0[i]; + dy_shape_[i] = outShape[i]; + } + } + } + + dx1->set_shape(x1->shape()); + dx2->set_shape(x2->shape()); + dx1->set_data_type(dy->data_type()); + dx2->set_data_type(dy->data_type()); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/arithmetic_grad.h b/mindspore/lite/src/ops/arithmetic_grad.h new file mode 100644 index 0000000000..9354c5ff99 --- /dev/null +++ b/mindspore/lite/src/ops/arithmetic_grad.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPS_ARITHMETIC_GRAD_H_ +#define MINDSPORE_LITE_SRC_OPS_ARITHMETIC_GRAD_H_ + +#include +#include +#include +#include "ir/dtype/type_id.h" +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class ArithmeticGrad : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(ArithmeticGrad, PrimitiveC); + ArithmeticGrad() = default; + explicit ArithmeticGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#else + // explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {} + ArithmeticGrad() = default; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { + return RET_ERROR; + } +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + bool Broadcasting() { return this->broadcasting_; } + int NDims() { return this->ndim_; } + std::vector dyShape() { return this->dy_shape_; } + std::vector x1Shape() { return this->x1_shape_; } + std::vector x2Shape() { return this->x2_shape_; } + + protected: + bool broadcasting_ = false; + int ndim_; + std::vector dy_shape_; + std::vector x1_shape_; + std::vector x2_shape_; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_OPS_ARITHMETIC_GRAD_H_ diff --git a/mindspore/lite/src/ops/bias_grad.cc b/mindspore/lite/src/ops/bias_grad.cc index 23c01adc3e..d571ed3d09 100644 --- a/mindspore/lite/src/ops/bias_grad.cc +++ b/mindspore/lite/src/ops/bias_grad.cc @@ -48,6 +48,32 @@ std::vector BiasGrad::GetAxis() const { return std::vector(fb_vector->begin(), fb_vector->end()); } +int BiasGrad::InferShape(std::vector inputs, std::vector outputs) { + if (1 != inputs.size()) { + MS_LOG(ERROR) << "BiasGrad should have one input"; + return RET_ERROR; + } + if (1 != outputs.size()) { + MS_LOG(ERROR) << "BiasGrad should have one output"; + return RET_ERROR; + } + auto *in0 = inputs.front(); + auto *out = outputs.front(); + MS_ASSERT(in0 != nullptr); + MS_ASSERT(out != nullptr); + auto inshape = in0->shape(); + int ndim = inshape.size(); + for (int i = 0; i < ndim - 1; i++) { + inshape[i] = 1; + } + out->set_shape(inshape); + out->set_data_type(in0->data_type()); + out->SetFormat(in0->GetFormat()); + + return RET_OK; +} + + #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/bias_grad.h b/mindspore/lite/src/ops/bias_grad.h index c3729764c1..3753296174 100644 --- a/mindspore/lite/src/ops/bias_grad.h +++ b/mindspore/lite/src/ops/bias_grad.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_BIAS_GRAD_H_ +#define MINDSPORE_LITE_SRC_OPS_BIAS_GRAD_H_ #include #include @@ -38,10 +38,11 @@ class BiasGrad : public PrimitiveC { BiasGrad() = default; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; + int InferShape(std::vector inputs, std::vector outputs) override; #endif std::vector GetAxis() const; }; } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_OPS_BIAS_GRAD_H_ diff --git a/mindspore/lite/src/ops/bn_grad_input.cc b/mindspore/lite/src/ops/bn_grad.cc similarity index 53% rename from mindspore/lite/src/ops/bn_grad_input.cc rename to mindspore/lite/src/ops/bn_grad.cc index 9aee03f81d..5fa2694e43 100644 --- a/mindspore/lite/src/ops/bn_grad_input.cc +++ b/mindspore/lite/src/ops/bn_grad.cc @@ -14,33 +14,33 @@ * limitations under the License. */ -#include "src/ops/bn_grad_input.h" +#include "src/ops/bn_grad.h" namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; } -int BNGradInput::GetChannels() const { return this->primitive_->value.AsBNGradInput()->channels; } +float BNGrad::GetEps() const { return this->primitive_->value.AsBNGrad()->eps; } +float BNGrad::GetMomentum() const { return this->primitive_->value.AsBNGrad()->momentum; } -void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; } -void BNGradInput::SetChannels(int channels) { this->primitive_->value.AsBNGradInput()->channels = channels; } +void BNGrad::SetEps(float eps) { this->primitive_->value.AsBNGrad()->eps = eps; } +void BNGrad::SetMomentum(float momentum) { this->primitive_->value.AsBNGrad()->momentum = momentum; } #else -int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { +int BNGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != fbb); - auto attr = primitive->value_as_BNGradInput(); + auto attr = primitive->value_as_BNGrad(); if (attr == nullptr) { MS_LOG(ERROR) << "value_as_BNGradInput return nullptr"; return RET_ERROR; } - auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->channels()); - auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o); + auto val_offset = schema::CreateBNGrad(*fbb, attr->eps(), attr->momentum()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGrad, val_offset.o); fbb->Finish(prim_offset); return RET_OK; } -float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); } -int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); } +float BNGrad::GetEps() const { return this->primitive_->value_as_BNGrad()->eps(); } +float BNGrad::GetMomentum() const { return this->primitive_->value_as_BNGrad()->momentum(); } #endif } // namespace lite diff --git a/mindspore/lite/src/ops/bn_grad_input.h b/mindspore/lite/src/ops/bn_grad.h similarity index 80% rename from mindspore/lite/src/ops/bn_grad_input.h rename to mindspore/lite/src/ops/bn_grad.h index aa22933f8a..e346593a53 100644 --- a/mindspore/lite/src/ops/bn_grad_input.h +++ b/mindspore/lite/src/ops/bn_grad.h @@ -25,21 +25,20 @@ namespace mindspore { namespace lite { -class BNGradInput : public PrimitiveC { +class BNGrad : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(BNGradInput, PrimitiveC); - BNGradInput() = default; - explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + MS_DECLARE_PARENT(BNGrad, PrimitiveC); + BNGrad() = default; + explicit BNGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetEps(float eps); - void SetChannels(int channels); + void SetMomentum(float momentum); #else - BNGradInput() = default; - + BNGrad() = default; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetEps() const; - int GetChannels() const; + float GetMomentum() const; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc index f3ef4d36e1..3286685294 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.cc +++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc @@ -105,5 +105,47 @@ int Conv2DGradFilter::GetActivationType() const { } #endif + +int Conv2DGradFilter::InferShape(std::vector inputs, std::vector outputs) { + if (3 != inputs.size()) { + MS_LOG(ERROR) << "Conv2d Grad Filter should have 3 inputs"; + return RET_ERROR; + } + if (1 != outputs.size()) { + MS_LOG(ERROR) << "Conv2d Grad Filter should have one output"; + return RET_ERROR; + } + + auto *in0 = inputs.at(0); + auto *in = inputs.at(2); + MS_ASSERT(out != nullptr); + + std::vector output_shape; + int *out_shape = reinterpret_cast(in->Data()); + int new_size = in->ElementsNum(); + if (in0->GetFormat() == in->GetFormat()) { + for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]); + } else { + if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) { + output_shape.push_back(out_shape[0]); + output_shape.push_back(out_shape[2]); + output_shape.push_back(out_shape[3]); + output_shape.push_back(out_shape[1]); + } else { + MS_LOG(ERROR) << "Shape covnert is not supported"; + return RET_ERROR; + } + } + + auto *out = outputs.at(0); + MS_ASSERT(out != nullptr); + + out->set_shape(output_shape); + out->set_data_type(in0->data_type()); + out->SetFormat(in0->GetFormat()); + + return RET_OK; +} + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.h b/mindspore/lite/src/ops/conv2d_grad_filter.h index 54fd9a3bf0..96c189d3be 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.h +++ b/mindspore/lite/src/ops/conv2d_grad_filter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_FILTER_H_ +#define MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_FILTER_H_ #include #include @@ -53,6 +53,7 @@ class Conv2DGradFilter : public PrimitiveC { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif + int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; int GetGroup() const; int GetChannelIn() const; @@ -74,4 +75,4 @@ class Conv2DGradFilter : public PrimitiveC { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_ +#endif // MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_FILTER_H_ diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc index a8a26d2bc2..1c078b9cd6 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.cc +++ b/mindspore/lite/src/ops/conv2d_grad_input.cc @@ -103,5 +103,46 @@ int Conv2DGradInput::GetActivationType() const { } #endif + +int Conv2DGradInput::InferShape(std::vector inputs, std::vector outputs) { + if (3 != inputs.size()) { + MS_LOG(ERROR) << "Conv2d Grad Input should have 3 inputs"; + return RET_ERROR; + } + if (1 != outputs.size()) { + MS_LOG(ERROR) << "Conv2d Grad input should have one output"; + return RET_ERROR; + } + + auto *in0 = inputs.at(0); + auto *in = inputs.at(2); + MS_ASSERT(out != nullptr); + + std::vector output_shape; + int *out_shape = reinterpret_cast(in->Data()); + int new_size = in->ElementsNum(); + if (in0->GetFormat() == in->GetFormat()) { + for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]); + } else { + if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) { + output_shape.push_back(out_shape[0]); + output_shape.push_back(out_shape[2]); + output_shape.push_back(out_shape[3]); + output_shape.push_back(out_shape[1]); + } else { + MS_LOG(ERROR) << "Shape covnert is not supported"; + return RET_ERROR; + } + } + + auto *out = outputs.at(0); + MS_ASSERT(out != nullptr); + out->set_shape(output_shape); + out->set_data_type(in0->data_type()); + out->SetFormat(in0->GetFormat()); + + return RET_OK; +} + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/conv2d_grad_input.h b/mindspore/lite/src/ops/conv2d_grad_input.h index 7d8cd2582a..d6dab8522b 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.h +++ b/mindspore/lite/src/ops/conv2d_grad_input.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_INPUT_H_ +#define MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_INPUT_H_ #include #include @@ -53,6 +53,7 @@ class Conv2DGradInput : public PrimitiveC { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif + int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; int GetGroup() const; int GetChannelIn() const; @@ -74,4 +75,4 @@ class Conv2DGradInput : public PrimitiveC { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_ +#endif // MINDSPORE_LITE_SRC_OPS_CONV2D_GRAD_INPUT_H_ diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.h b/mindspore/lite/src/ops/dedepthwise_conv2d.h index 142ce5b1f4..25e40421aa 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.h +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_ -#define LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_DEDEPTHWISE_CONV2D_H_ +#define MINDSPORE_LITE_SRC_OPS_DEDEPTHWISE_CONV2D_H_ #include #include @@ -84,4 +84,4 @@ class DeDepthwiseConv2D : public PrimitiveC { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_ +#endif // MINDSPORE_LITE_SRC_OPS_DEDEPTHWISE_CONV2D_H_ diff --git a/mindspore/lite/src/ops/depthwise_conv2d.h b/mindspore/lite/src/ops/depthwise_conv2d.h index 877f083f6f..e64d7a3262 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.h +++ b/mindspore/lite/src/ops/depthwise_conv2d.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_ -#define LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_DEPTHWISE_CONV2D_H_ +#define MINDSPORE_LITE_SRC_OPS_DEPTHWISE_CONV2D_H_ #include #include @@ -94,4 +94,4 @@ class DepthwiseConv2D : public PrimitiveC { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_ +#endif // MINDSPORE_LITE_SRC_OPS_DEPTHWISE_CONV2D_H_ diff --git a/mindspore/lite/src/ops/make_tuple.h b/mindspore/lite/src/ops/make_tuple.h index 04c621b587..689ae0efdc 100644 --- a/mindspore/lite/src/ops/make_tuple.h +++ b/mindspore/lite/src/ops/make_tuple.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_ -#define LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_ +#define MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_ #include #include "src/ops/primitive_c.h" @@ -37,4 +37,4 @@ class MakeTuple : public PrimitiveC { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_ +#endif // MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_ diff --git a/mindspore/lite/src/ops/pooling_grad.cc b/mindspore/lite/src/ops/pooling_grad.cc index 654a3cb047..17bc053c49 100644 --- a/mindspore/lite/src/ops/pooling_grad.cc +++ b/mindspore/lite/src/ops/pooling_grad.cc @@ -86,5 +86,52 @@ int PoolingGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf return RET_OK; } #endif + +int PoolingGrad::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + int input_h = input->shape().at(1); + int input_w = input->shape().at(2); + + auto window_h = GetWindowH(); + auto window_w = GetWindowW(); + if (GetGlobal()) { + window_h = input_h; + window_w = input_w; + } + + pad_l_ = GetPadLeft(); + pad_u_ = GetPadUp(); + pad_d_ = GetPadDown(); + pad_r_ = GetPadRight(); + if (GetPadMode() == schema::PadMode_SAME) { + int output_w = std::ceil(static_cast(input_w) / static_cast(GetStrideW())); + int output_h = std::ceil(static_cast(input_h) / static_cast(GetStrideH())); + auto pad_h_all = ((output_h - 1) * GetStrideH() + (window_h - 1) + 1 - input_h); + auto pad_w_all = ((output_w - 1) * GetStrideW() + (window_w - 1) + 1 - input_w); + if (pad_h_all < 0) { + pad_u_ = pad_d_ = 0; + } else { + pad_u_ = pad_h_all / 2; + pad_d_ = pad_h_all - pad_u_; + } + if (pad_w_all < 0) { + pad_l_ = pad_r_ = 0; + } else { + pad_l_ = pad_w_all / 2; + pad_r_ = pad_w_all - pad_l_; + } + } + auto grad_output = outputs_.at(0); + // todo: fmk type + auto output_shape = input->shape(); + grad_output->set_shape(output_shape); + grad_output->set_data_type(input->data_type()); + // todo: temp fix + grad_output->SetFormat(input->GetFormat()); + return RET_OK; +} + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/pooling_grad.h b/mindspore/lite/src/ops/pooling_grad.h index c42c5d72d6..dbafdb9254 100644 --- a/mindspore/lite/src/ops/pooling_grad.h +++ b/mindspore/lite/src/ops/pooling_grad.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_POOLING_GRAD_H_ +#define MINDSPORE_LITE_SRC_OPS_POOLING_GRAD_H_ #include #include @@ -49,6 +49,7 @@ class PoolingGrad : public PrimitiveC { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif + int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; int GetPoolingMode() const; bool GetGlobal() const; @@ -62,8 +63,14 @@ class PoolingGrad : public PrimitiveC { int GetPadLeft() const; int GetPadRight() const; int GetRoundMode() const; + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; }; } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_OPS_POOLING_GRAD_H_ diff --git a/mindspore/lite/src/ops/power_grad.h b/mindspore/lite/src/ops/power_grad.h index 9cb95d696f..a3fbd79986 100644 --- a/mindspore/lite/src/ops/power_grad.h +++ b/mindspore/lite/src/ops/power_grad.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_POWER_GRAD_H_ +#define MINDSPORE_LITE_SRC_OPS_POWER_GRAD_H_ #include #include @@ -46,4 +46,4 @@ class PowerGrad : public PrimitiveC { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_OPS_POWER_GRAD_H_ diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 13033708dc..961cd6cf18 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -125,6 +125,21 @@ #ifdef PRIMITIVE_WRITEABLE #include "tools/converter/quantizer/quantize_util.h" #endif + +#ifdef SUPPORT_TRAIN +#include "src/ops/activation_grad.h" +#include "src/ops/apply_momentum.h" +#include "src/ops/bias_grad.h" +#include "src/ops/pooling_grad.h" +#include "src/ops/conv2d_grad_filter.h" +#include "src/ops/conv2d_grad_input.h" +#include "src/ops/power_grad.h" +#include "src/ops/softmax_cross_entropy.h" +#include "src/ops/bn_grad.h" +#include "src/ops/arithmetic_grad.h" +#endif + + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -353,6 +368,22 @@ std::shared_ptr PrimitiveC::UnPackFromPrimitive(const Primitive &pri return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Softmax") { return NewPrimitiveC(prim, inputs, quantType); +#ifdef SUPPORT_TRAIN0 + } else if ((op_type == "ReluGrad" || op_type == "Relu6Grad" || op_type == "SigmoidGrad")) { + return NewPrimitiveC(prim, inputs, quantType); + } else if ((op_type == "MaxPoolGrad") || (op_type == "MeanPoolGrad")) { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Conv2DBackpropFilter") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Conv2DBackpropInput") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "BiasAddGrad") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "ApplyMomentum") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "BatchNormGrad") { + return NewPrimitiveC(prim, inputs, quantType); +#endif } else { MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type; return nullptr; @@ -565,6 +596,32 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT return new SparseToDense(primitive); case schema::PrimitiveType_DetectionPostProcess: return new DetectionPostProcess(primitive); + +#ifdef SUPPORT_TRAIN + case schema::PrimitiveType_ActivationGrad: + return new ActivationGrad(primitive); + case schema::PrimitiveType_PoolingGrad: + return new PoolingGrad(primitive); + case schema::PrimitiveType_Conv2DGradFilter: + return new Conv2DGradFilter(primitive); + case schema::PrimitiveType_Conv2DGradInput: + return new Conv2DGradInput(primitive); + case schema::PrimitiveType_BiasGrad: + return new BiasGrad(primitive); + case schema::PrimitiveType_ApplyMomentum: + return new ApplyMomentum(primitive); + case schema::PrimitiveType_BNGrad: + return new BNGrad(primitive); + case schema::PrimitiveType_AddGrad: + return new ArithmeticGrad(primitive); + case schema::PrimitiveType_SubGrad: + return new ArithmeticGrad(primitive); + case schema::PrimitiveType_MulGrad: + return new ArithmeticGrad(primitive); + case schema::PrimitiveType_DivGrad: + return new ArithmeticGrad(primitive); +#endif + default: MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitiveT : " << schema::EnumNamePrimitiveType(op_type); @@ -779,6 +836,31 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primi return NewPrimitiveC(primitive); case schema::PrimitiveType_DetectionPostProcess: return NewPrimitiveC(primitive); + +#ifdef SUPPORT_TRAIN + case schema::PrimitiveType_ActivationGrad: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_PoolingGrad: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_Conv2DGradFilter: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_Conv2DGradInput: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_BiasGrad: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_ApplyMomentum: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_BNGrad: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_AddGrad: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_SubGrad: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_MulGrad: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_DivGrad: + return NewPrimitiveC(primitive); +#endif default: MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitive : " << schema::EnumNamePrimitiveType(op_type); diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index ea774e02b4..7b7af97b5b 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -115,7 +115,7 @@ constexpr size_t kInputSize = 1; constexpr size_t kOutputSize = 1; } // namespace int Reduce::InferShape(std::vector inputs_, std::vector outputs_) { - if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) { + if (inputs_.size() < kInputSize || outputs_.size() != kOutputSize) { return RET_ERROR; } auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/reshape.h b/mindspore/lite/src/ops/reshape.h index c81187636b..b7b3760946 100644 --- a/mindspore/lite/src/ops/reshape.h +++ b/mindspore/lite/src/ops/reshape.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_RESHAPE_H_ +#define MINDSPORE_LITE_SRC_OPS_RESHAPE_H_ #include #include @@ -50,4 +50,4 @@ class Reshape : public PrimitiveC { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_ +#endif // MINDSPORE_LITE_SRC_OPS_RESHAPE_H_ diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.cc b/mindspore/lite/src/ops/softmax_cross_entropy.cc index 8be8ca1d88..41e6c1e22f 100644 --- a/mindspore/lite/src/ops/softmax_cross_entropy.cc +++ b/mindspore/lite/src/ops/softmax_cross_entropy.cc @@ -51,5 +51,31 @@ int SoftmaxCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, return RET_OK; } #endif + +int SoftmaxCrossEntropy::InferShape(std::vector inputs, std::vector outputs) { + if (1 > outputs.size()) { + MS_LOG(ERROR) << "SoftmaxCrossEntropy should have at least one output"; + return RET_ERROR; + } + auto *in0 = inputs.front(); + MS_ASSERT(in0 != nullptr); + auto *out = outputs.front(); + MS_ASSERT(out != nullptr); + + std::vector outshape; + outshape.push_back(1); + out->set_shape(outshape); + out->set_data_type(in0->data_type()); + + if (1 < outputs.size()) { + auto *grads = outputs.at(1); + MS_ASSERT(grads != nullptr); + grads->set_shape(in0->shape()); + grads->set_data_type(in0->data_type()); + grads->SetFormat(in0->GetFormat()); + } + return RET_OK; +} + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.h b/mindspore/lite/src/ops/softmax_cross_entropy.h index 44449d0bfd..b81a435abe 100644 --- a/mindspore/lite/src/ops/softmax_cross_entropy.h +++ b/mindspore/lite/src/ops/softmax_cross_entropy.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_SOFTMAX_CROSS_ENTROPY_H_ +#define MINDSPORE_LITE_SRC_OPS_SOFTMAX_CROSS_ENTROPY_H_ #include #include @@ -39,9 +39,11 @@ class SoftmaxCrossEntropy : public PrimitiveC { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + std::vector GetAxis() const; }; } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_ +#endif // MINDSPORE_LITE_SRC_OPS_SOFTMAX_CROSS_ENTROPY_H_ diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index d25bd4bfd6..0690728584 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -1678,6 +1678,13 @@ PopulateParameterFunc PopulateParameterRegistry::GetParameterFunc(int type) { return populate_parameter_funcs_[schema::PrimitiveType(type)]; } +int PopulateParameterRegistry::AddPopulateParameterFunc(const schema::PrimitiveType &type, PopulateParameterFunc func) { + if ((type < schema::PrimitiveType_MIN)|| (type > schema::PrimitiveType_MAX)) + return -1; + populate_parameter_funcs_[type] = func; + return 0; +} + OpParameter *PopulateParameter(const mindspore::lite::PrimitiveC *primitive) { if (primitive == nullptr) { MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; diff --git a/mindspore/lite/src/populate_parameter.h b/mindspore/lite/src/populate_parameter.h index 5397b9bb37..de994dfc65 100644 --- a/mindspore/lite/src/populate_parameter.h +++ b/mindspore/lite/src/populate_parameter.h @@ -30,12 +30,16 @@ class PopulateParameterRegistry { ~PopulateParameterRegistry() = default; static PopulateParameterRegistry *GetInstance(); + int AddPopulateParameterFunc(const schema::PrimitiveType &type, PopulateParameterFunc func); PopulateParameterFunc GetParameterFunc(int type); protected: PopulateParameterFunc populate_parameter_funcs_[schema::PrimitiveType_MAX + 1]; }; +OpParameter *PopulateActivationParameter(const lite::PrimitiveC *primitive); +OpParameter *PopulateArithmetic(const lite::PrimitiveC *primitive); + OpParameter *PopulateParameter(const mindspore::lite::PrimitiveC *primitive); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc index 775ee81d03..9f80f402a5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc @@ -37,8 +37,8 @@ constexpr size_t kOutputNum = 1; } // namespace int ReduceBaseCPUKernel::CheckInputsOutputs() { - if (in_tensors_.size() != kInputNum) { - MS_LOG(ERROR) << "Reduce inputs size should be " << kInputNum << " but got " << in_tensors_.size(); + if (in_tensors_.size() < kInputNum) { + MS_LOG(ERROR) << "Reduce inputs size should be at least " << kInputNum << " but got " << in_tensors_.size(); return RET_ERROR; } if (out_tensors_.size() != kOutputNum) { @@ -99,7 +99,15 @@ int ReduceBaseCPUKernel::Init() { if (reduce_param == nullptr) { return RET_NULL_PTR; } - num_axes_ = reduce_param->num_axes_; + if (in_tensors_.size() > 1) { + auto axes_ptr = in_tensors_.at(1); + num_axes_ = axes_ptr->ElementsNum(); + memcpy(axes_, axes_ptr->Data(), axes_ptr->Size()); + } else { + num_axes_ = reduce_param->num_axes_; + memcpy(axes_, reduce_param->axes_, sizeof(reduce_param->axes_)); + } + mode_ = reduce_param->mode_; memcpy(axes_, reduce_param->axes_, sizeof(reduce_param->axes_)); reduce_to_end_ = reduce_param->reduce_to_end_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc index 6cc4999499..0e15719b89 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc @@ -15,6 +15,7 @@ */ #include "src/runtime/kernel/arm/fp32_grad/activation_grad.h" +#include "nnacl/fp32_grad/activation_grad.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "src/runtime/runtime_api.h" @@ -24,41 +25,38 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; -using mindspore::schema::ActivationGradType_HSWISH; -using mindspore::schema::ActivationGradType_LEAKY_RELU; -using mindspore::schema::ActivationGradType_RELU; -using mindspore::schema::ActivationGradType_RELU6; +using mindspore::schema::ActivationType_HSWISH; +using mindspore::schema::ActivationType_LEAKY_RELU; +using mindspore::schema::ActivationType_RELU; +using mindspore::schema::ActivationType_RELU6; using mindspore::schema::PrimitiveType_ActivationGrad; namespace mindspore::kernel { -int ActivationGradCPUKernel::Init() { - outputs_[0]->set_shape(inputs_[0]->shape()); - return RET_OK; -} +int ActivationGradCPUKernel::Init() { return RET_OK; } int ActivationGradCPUKernel::ReSize() { return RET_OK; } int ActivationGradCPUKernel::DoActivation(int task_id) { - auto yt_addr = reinterpret_cast(inputs_.at(0)->Data()); - auto input_addr = reinterpret_cast(inputs_.at(1)->Data()); - auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); - auto length = inputs_.at(0)->ElementsNum(); + auto yt_addr = reinterpret_cast(in_tensors_.at(0)->Data()); + auto input_addr = reinterpret_cast(in_tensors_.at(1)->Data()); + auto output_addr = reinterpret_cast(out_tensors_.at(0)->Data()); + int length = in_tensors_.at(0)->ElementsNum(); auto error_code = RET_OK; - if (type_ == schema::ActivationGradType_RELU) { + if (param_act_grad_->type_ == schema::ActivationType_RELU) { error_code = ReluGrad(yt_addr, input_addr, length, output_addr); - } else if (type_ == schema::ActivationGradType_RELU6) { + } else if (param_act_grad_->type_ == schema::ActivationType_RELU6) { error_code = Relu6Grad(yt_addr, input_addr, length, output_addr); - } else if (type_ == schema::ActivationGradType_LEAKY_RELU) { - error_code = LReluGrad(yt_addr, input_addr, length, output_addr, alpha_); - } else if (type_ == schema::ActivationGradType_SIGMOID) { + } else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) { + error_code = LReluGrad(yt_addr, input_addr, length, output_addr, param_act_grad_->alpha_); + } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) { error_code = SigmoidGrad(yt_addr, input_addr, length, output_addr); - } else if (type_ == schema::ActivationGradType_TANH) { + } else if (param_act_grad_->type_ == schema::ActivationType_TANH) { error_code = TanhGrad(yt_addr, input_addr, length, output_addr); - } else if (type_ == schema::ActivationGradType_HSWISH) { + } else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) { error_code = HSwishGrad(yt_addr, input_addr, length, output_addr); - } else if (type_ == schema::ActivationGradType_HSIGMOID) { + } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) { error_code = HSigmoidGrad(yt_addr, input_addr, length, output_addr); } else { MS_LOG(ERROR) << "Activation type error"; @@ -81,6 +79,12 @@ int ActivationGradRun(void *cdata, int task_id) { } int ActivationGradCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return ret; + } + int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ActivationGradRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]"; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h index b6236385b2..7f001e7109 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h @@ -20,8 +20,7 @@ #include #include "src/lite_kernel.h" #include "ir/anf.h" - -#include "nnacl/activation_grad.h" +#include "nnacl/fp32/activation.h" namespace mindspore::kernel { class ActivationGradCPUKernel : public LiteKernel { @@ -30,9 +29,7 @@ class ActivationGradCPUKernel : public LiteKernel { const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(param, inputs, outputs, ctx, primitive) { - ActivationGradParameter *param_act_grad = reinterpret_cast(param); - type_ = param_act_grad->type_; - alpha_ = param_act_grad->alpha_; + param_act_grad_ = reinterpret_cast(param); } ~ActivationGradCPUKernel() override = default; @@ -43,9 +40,9 @@ class ActivationGradCPUKernel : public LiteKernel { private: int thread_count_; - int type_; - float alpha_; + ActivationParameter *param_act_grad_; }; + } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ACTIVATION_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc new file mode 100644 index 0000000000..4fa49f87bb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc @@ -0,0 +1,105 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32_grad/apply_momentum.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ApplyMomentum; + +namespace mindspore::kernel { + +int ApplyMomentumCPUKernel::ReSize() { return RET_OK; } + +int ApplyMomentumCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } + + auto weight = reinterpret_cast(in_tensors_[0]->Data()); + auto accumulate = reinterpret_cast(in_tensors_[1]->Data()); + float learning_rate = reinterpret_cast(in_tensors_[2]->Data())[0]; + auto gradient = reinterpret_cast(in_tensors_[3]->Data()); + float moment = reinterpret_cast(in_tensors_[4]->Data())[0]; + size_t elem_num = in_tensors_[0]->ElementsNum(); + + // align format + if (in_tensors_[3]->shape().size() == 4 && + in_tensors_[3]->GetFormat() == schema::Format_NCHW && + in_tensors_[0]->GetFormat() == schema::Format_KHWC) { + PackNCHWToNHWCFp32(gradient, workspace, in_tensors_[0]->Batch(), in_tensors_[0]->Height() * in_tensors_[0]->Width(), + in_tensors_[0]->Channel()); + } else { + memcpy(workspace, gradient, in_tensors_[3]->ElementsNum() * sizeof(float)); + } + + for (size_t i = 0; i < elem_num; ++i) { + accumulate[i] = accumulate[i] * moment + workspace[i]; // * (1.0 - moment); + weight[i] -= accumulate[i] * learning_rate; + } + return RET_OK; +} + +int ApplyMomentumCPUKernel::Init() { + // Only for test with uninitialized Data + size_t elem_num = in_tensors_[0]->ElementsNum(); + auto accumulate = reinterpret_cast(in_tensors_[1]->Data()); + for (int i =0; i < elem_num; i++) accumulate[i] = 0.0; + + workspace = new float[elem_num]; + return 0; +} +#if 0 +OpParameter *PopulateApplyMomentumParameter(const lite::Primitive *primitive) { + OpParameter *param = new (std::nothrow) OpParameter(); + if (param == nullptr) { + MS_LOG(ERROR) << "new Param for OptMomentum failed."; + return nullptr; + } + param->type_ = primitive->Type(); + return param; +} +#endif + +kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) { + MS_ASSERT(desc.type == schema::PrimitiveType_ApplyMomentum); + auto *kernel = new (std::nothrow) ApplyMomentumCPUKernel(opParameter, inputs, outputs, ctx, primitive); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ApplyMomentum, CpuApplyMomentumFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h similarity index 67% rename from mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.h rename to mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h index 8603363003..c2d9f6ed31 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h @@ -14,28 +14,32 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_OPT_MOMENTUM_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_OPT_MOMENTUM_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_ #include #include "src/lite_kernel.h" #include "ir/anf.h" namespace mindspore::kernel { -class OptMomentumCPUKernel : public LiteKernel { +class ApplyMomentumCPUKernel : public LiteKernel { public: - explicit OptMomentumCPUKernel(OpParameter *parameter, const std::vector &inputs, + explicit ApplyMomentumCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~OptMomentumCPUKernel() override {} + ~ApplyMomentumCPUKernel() override {delete [] workspace;} int Init() override; int ReSize() override; int Run() override; private: + float *workspace; }; + +// OpParameter *PopulateApplyMomentumParameter(const lite::Primitive *primitive); + } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_OPT_MOMENTUM_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc index 6e1cc3c696..fa1c99a67d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc @@ -14,11 +14,11 @@ * limitations under the License. */ +#include "src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "nnacl/fp32_grad/reduce_grad.h" #include "nnacl/fp32_grad/arithmetic_grad.h" -#include "src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; @@ -33,108 +33,41 @@ constexpr int kArithGradOpOutputNum = 2; } // namespace int ArithmeticGradCPUKernel::Init() { - auto ret = InferShape(); - return ret; -} - -int ArithmeticGradCPUKernel::InferShape() { - if (inputs_.size() != kArithGradOpInputNum) { - MS_LOG(ERROR) << "The number of input must be " << kArithGradOpInputNum; - return RET_ERROR; - } - if (outputs_.size() != kArithGradOpOutputNum) { - MS_LOG(ERROR) << "The number of output must be " << kArithGradOpOutputNum; - return RET_ERROR; - } - auto dy = inputs_[0]; - auto x1 = inputs_[1]; - auto x2 = inputs_[2]; - auto dx1 = outputs_[0]; - auto dx2 = outputs_[1]; + auto dx1 = out_tensors_[0]; + auto dx2 = out_tensors_[1]; - MS_ASSERT(dy != nullptr); - MS_ASSERT(x1 != nullptr); - MS_ASSERT(x2 != nullptr); MS_ASSERT(dx1 != nullptr); MS_ASSERT(dx2 != nullptr); - auto inShape0 = x1->shape(); - auto inShape1 = x2->shape(); - auto outShape = dy->shape(); - - if ((type() == PrimitiveType_AddGrad) || (type() == PrimitiveType_SubGrad)) { - arithmeticParameter_->ndim_ = outShape.size(); - auto fillDimNum0 = outShape.size() - inShape0.size(); - auto fillDimNum1 = outShape.size() - inShape1.size(); - int j0 = 0; - int j1 = 0; - for (unsigned int i = 0; i < outShape.size(); i++) { - arithmeticParameter_->in_shape0_[i] = (i < fillDimNum0) ? 1 : inShape0[j0++]; - arithmeticParameter_->in_shape1_[i] = (i < fillDimNum1) ? 1 : inShape1[j1++]; - arithmeticParameter_->out_shape_[i] = outShape[i]; - } - } else { + if ((Type() == PrimitiveType_MulGrad) || (Type() == PrimitiveType_DivGrad)) { // if (inShape0.size() < inShape1.size()) if (dx1->ElementsNum() < dx2->ElementsNum()) { - arithmeticParameter_->ndim_ = inShape1.size(); - if (type() == PrimitiveType_MulGrad) + if (Type() == PrimitiveType_MulGrad) arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul2L; - else if (type() == PrimitiveType_DivGrad) + else if (Type() == PrimitiveType_DivGrad) arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv2L; - auto fillDimNum = inShape1.size() - inShape0.size(); // This will not work for batch! - int j = 0; - for (unsigned int i = 0; i < inShape1.size(); i++) { - if (i < fillDimNum) { - arithmeticParameter_->in_shape1_[i] = 1; - } else { - arithmeticParameter_->in_shape1_[i] = inShape0[j++]; - } - arithmeticParameter_->in_shape0_[i] = inShape1[i]; - arithmeticParameter_->out_shape_[i] = outShape[i]; - } } else if (dx2->ElementsNum() < dx1->ElementsNum()) { // if (inShape0.size() > inShape1.size()) - arithmeticParameter_->ndim_ = inShape0.size(); - if (type() == PrimitiveType_MulGrad) + if (Type() == PrimitiveType_MulGrad) arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul1L; - else if (type() == PrimitiveType_DivGrad) + else if (Type() == PrimitiveType_DivGrad) arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv1L; - arithmeticParameter_->broadcasting_ = true; - arithmeticParameter_->ndim_ = inShape0.size(); - int j = 0; - auto fillDimNum = inShape0.size() - inShape1.size(); - for (unsigned int i = 0; i < inShape0.size(); i++) { - if (i < fillDimNum) { - arithmeticParameter_->in_shape1_[i] = 1; - } else { - arithmeticParameter_->in_shape1_[i] = inShape1[j++]; - } - arithmeticParameter_->in_shape0_[i] = inShape0[i]; - arithmeticParameter_->out_shape_[i] = outShape[i]; - } - } else { - arithmeticParameter_->broadcasting_ = false; - for (unsigned int i = 0; i < inShape0.size(); i++) { - arithmeticParameter_->in_shape1_[i] = inShape1[i]; - arithmeticParameter_->in_shape0_[i] = inShape0[i]; - arithmeticParameter_->out_shape_[i] = outShape[i]; - } } - tile_data0 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()]; + tile_data0 = new (std::nothrow) float[in_tensors_.at(0)->ElementsNum()]; if (tile_data0 == nullptr) { MS_LOG(ERROR) << "new data0 fail!"; return RET_ERROR; } - tile_data1 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()]; + tile_data1 = new (std::nothrow) float[in_tensors_.at(0)->ElementsNum()]; if (tile_data1 == nullptr) { MS_LOG(ERROR) << "new data1 fail!"; delete tile_data0; return RET_ERROR; } - if (type() == PrimitiveType_DivGrad) { - tile_data2 = new (std::nothrow) float[inputs_.at(0)->ElementsNum()]; + if (Type() == PrimitiveType_DivGrad) { + tile_data2 = new (std::nothrow) float[in_tensors_.at(0)->ElementsNum()]; if (tile_data2 == nullptr) { MS_LOG(ERROR) << "new data2 fail!"; delete tile_data0; @@ -144,10 +77,6 @@ int ArithmeticGradCPUKernel::InferShape() { } } - dx1->set_shape(x1->shape()); - dx2->set_shape(x2->shape()); - dx1->set_data_type(dy->data_type()); - dx2->set_data_type(dy->data_type()); return RET_OK; } @@ -187,16 +116,16 @@ void ArithmeticGradCPUKernel::ArithmeticGradSub(float *dy, int dy_size, float *d void ArithmeticGradCPUKernel::ArithmeticGradMul(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size) { - auto x1_data = reinterpret_cast(inputs_[1]->Data()); - auto x2_data = reinterpret_cast(inputs_[2]->Data()); + auto x1_data = reinterpret_cast(in_tensors_[1]->Data()); + auto x2_data = reinterpret_cast(in_tensors_[2]->Data()); ElementMul(dy, x1_data, dx2, dy_size); ElementMul(dy, x2_data, dx1, dy_size); } void ArithmeticGradCPUKernel::ArithmeticGradMul1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size) { - auto x1_data = reinterpret_cast(inputs_[1]->Data()); - auto x2_data = reinterpret_cast(inputs_[2]->Data()); + auto x1_data = reinterpret_cast(in_tensors_[1]->Data()); + auto x2_data = reinterpret_cast(in_tensors_[2]->Data()); ElementMul(dy, x1_data, tile_data0, dy_size); ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx2, arithmeticParameter_->in_shape1_, arithmeticParameter_->ndim_); @@ -206,8 +135,8 @@ void ArithmeticGradCPUKernel::ArithmeticGradMul1L(float *dy, int dy_size, float void ArithmeticGradCPUKernel::ArithmeticGradMul2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size) { - auto x1_data = reinterpret_cast(inputs_[1]->Data()); - auto x2_data = reinterpret_cast(inputs_[2]->Data()); + auto x1_data = reinterpret_cast(in_tensors_[1]->Data()); + auto x2_data = reinterpret_cast(in_tensors_[2]->Data()); ElementMul(dy, x2_data, tile_data0, dy_size); ReduceSumByAxes(tile_data0, arithmeticParameter_->in_shape0_, dx1, arithmeticParameter_->in_shape1_, arithmeticParameter_->ndim_); @@ -217,16 +146,16 @@ void ArithmeticGradCPUKernel::ArithmeticGradMul2L(float *dy, int dy_size, float void ArithmeticGradCPUKernel::ArithmeticGradDiv(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size) { - auto x1 = reinterpret_cast(inputs_[1]->Data()); - auto x2 = reinterpret_cast(inputs_[2]->Data()); + auto x1 = reinterpret_cast(in_tensors_[1]->Data()); + auto x2 = reinterpret_cast(in_tensors_[2]->Data()); ElementDiv(dy, x2, dx1, dy_size); ElementMulAndDivNegSquare(dy, x1, x2, dx2, dy_size); } void ArithmeticGradCPUKernel::ArithmeticGradDiv1L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size) { - auto x1_data = reinterpret_cast(inputs_[1]->Data()); - auto x2_data = reinterpret_cast(inputs_[2]->Data()); + auto x1_data = reinterpret_cast(in_tensors_[1]->Data()); + auto x2_data = reinterpret_cast(in_tensors_[2]->Data()); ElementMul(x2_data, x2_data, dx2, dx2_size); ElementMul(x1_data, dy, dx1, dy_size); // use dx1 buffer @@ -243,8 +172,8 @@ void ArithmeticGradCPUKernel::ArithmeticGradDiv1L(float *dy, int dy_size, float void ArithmeticGradCPUKernel::ArithmeticGradDiv2L(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2, int dx2_size) { - auto x1_data = reinterpret_cast(inputs_[1]->Data()); - auto x2_data = reinterpret_cast(inputs_[2]->Data()); + auto x1_data = reinterpret_cast(in_tensors_[1]->Data()); + auto x2_data = reinterpret_cast(in_tensors_[2]->Data()); // dx1 = dy/x2 ElementDiv(dy, x2_data, tile_data0, dy_size); // first multiply into temp @@ -259,13 +188,13 @@ void ArithmeticGradCPUKernel::ArithmeticGradDiv2L(float *dy, int dy_size, float int ArithmeticGradCPUKernel::ReSize() { return RET_OK; } int ArithmeticGradCPUKernel::Run() { - auto dy = reinterpret_cast(inputs_[0]->Data()); - auto dx1 = reinterpret_cast(outputs_[0]->Data()); - auto dx2 = reinterpret_cast(outputs_[1]->Data()); + auto dy = reinterpret_cast(in_tensors_[0]->Data()); + auto dx1 = reinterpret_cast(out_tensors_[0]->Data()); + auto dx2 = reinterpret_cast(out_tensors_[1]->Data()); - size_t dy_size = inputs_.at(0)->ElementsNum(); - size_t dx1_size = outputs_.at(0)->ElementsNum(); - size_t dx2_size = outputs_[1]->ElementsNum(); + size_t dy_size = in_tensors_.at(0)->ElementsNum(); + size_t dx1_size = out_tensors_.at(0)->ElementsNum(); + size_t dx2_size = out_tensors_[1]->ElementsNum(); (this->*arithmetic_grad_)(dy, dy_size, dx1, dx1_size, dx2, dx2_size); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h index 7b5b94e0d1..f5548c05e5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.h @@ -40,7 +40,7 @@ class ArithmeticGradCPUKernel : public LiteKernel { const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive), tile_data0(NULL), tile_data1(NULL), tile_data2(NULL) { - switch (type()) { + switch (Type()) { case PrimitiveType_MulGrad: arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul; // this will be adjusted in InferShape break; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc index 4f4537598e..f6f914c34e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc @@ -27,33 +27,9 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_BiasGrad; namespace mindspore::kernel { -int BiasGradCPUKernel::InferShape() { - if (1 != this->inputs_.size()) { - MS_LOG(ERROR) << "BiasGrad should have one input"; - return RET_ERROR; - } - if (1 != this->outputs_.size()) { - MS_LOG(ERROR) << "BiasGrad should have one output"; - return RET_ERROR; - } - auto *in0 = inputs_.front(); - auto *out = outputs_.front(); - MS_ASSERT(in0 != nullptr); - MS_ASSERT(out != nullptr); - auto inshape = in0->shape(); - int ndim = inshape.size(); - for (int i = 0; i < ndim - 1; i++) { - inshape[i] = 1; - } - out->set_shape(inshape); - out->set_data_type(in0->data_type()); - return RET_OK; -} int BiasGradCPUKernel::Init() { - MS_ASSERT(InferShape() == RET_OK); - - auto dims = inputs_[0]->shape(); + auto dims = in_tensors_[0]->shape(); bias_param->ndim_ = dims.size(); for (unsigned int i = 0; i < bias_param->ndim_; i++) { bias_param->in_shape0_[i] = dims[i]; @@ -75,8 +51,8 @@ int BiasGradCPUKernel::Run() { MS_LOG(ERROR) << "Prepare failed."; return RET_ERROR; } - auto in = reinterpret_cast(inputs_.at(0)->Data()); - auto out = reinterpret_cast(outputs_.at(0)->Data()); + auto in = reinterpret_cast(in_tensors_.at(0)->Data()); + auto out = reinterpret_cast(out_tensors_.at(0)->Data()); size_t nhw_size = 1; size_t channels = bias_param->in_shape0_[bias_param->ndim_ - 1]; // C in NHWC diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h index 1b8596210b..99427556f9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BIAS_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BIAS_GRAD_H_ #include #include "src/lite_kernel.h" @@ -35,7 +35,6 @@ class BiasGradCPUKernel : public LiteKernel { ~BiasGradCPUKernel() override = default; int Init() override; - int InferShape(); int ReSize() override; int Run() override; @@ -44,4 +43,4 @@ class BiasGradCPUKernel : public LiteKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_GRAD_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BIAS_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc index 62cff244dd..6aba3fe8c6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc @@ -14,11 +14,11 @@ * limitations under the License. */ +#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h" #include #include #include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h" #include "nnacl/fp32_grad/batch_norm.h" #include "include/errorcode.h" @@ -27,79 +27,103 @@ using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; // using mindspore::lite::REG_OP; -using mindspore::schema::PrimitiveType_BNGradInput; +using mindspore::schema::PrimitiveType_BNGrad; +/* +{dy} +{x } +{scale } +{save_mean } +{save_inv_variance } +*/ namespace mindspore::kernel { -int BNGradInputCPUKernel::Init() { - auto bn_param = reinterpret_cast(opParameter); - workspace_size = 5 * bn_param->channels; - workspace = new (std::nothrow) float[workspace_size]; - if (workspace == nullptr) { - MS_LOG(ERROR) << "new workspace fail!"; - return RET_ERROR; - } - if (2 != this->inputs_.size()) { - MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs"; - return RET_ERROR; +#if 0 +OpParameter *PopulateBNGradParameter(const lite::Primitive *primitive) { + BNGradParameter *param = new (std::nothrow) BNGradParameter(); + if (param == nullptr) { + MS_LOG(ERROR) << "new Param for conv grad filter failed."; + return nullptr; } - if (1 != this->outputs_.size()) { - MS_LOG(ERROR) << "Conv2d Grad should has one output"; + param->op_parameter_.type_ = primitive->Type(); + + auto bngrad_primitive = primitive->Value()->value_as_BNGrad(); + param->epsilon_ = bngrad_primitive->eps(); + param->momentum_ = bngrad_primitive->momentum(); + return reinterpret_cast(param); +} +#endif +int BNGradCPUKernel::Init() { + auto *input_x = in_tensors_.at(1); + int channels = input_x->shape().at(kNHWC_C); + workspace_size = 5 * channels; + workspace = new (std::nothrow) float[workspace_size]; + if (workspace == nullptr) { + MS_LOG(ERROR) << "new workspace fail!"; return RET_ERROR; } - auto *input_tensor = inputs_.at(0); - auto *out_tensor = outputs_.at(0); - auto in_shape = input_tensor->shape(); - out_tensor->set_shape(in_shape); - out_tensor->set_data_type(input_tensor->data_type()); return RET_OK; } -int BNGradInputCPUKernel::ReSize() { return RET_OK; } +int BNGradCPUKernel::ReSize() { return RET_OK; } -int BNGradInputCPUKernel::Run() { - auto *input_x = inputs_.at(0); - auto *input_yt = inputs_.at(1); - auto *input_scale = inputs_.at(2); - auto *output_grad = outputs_.at(0); - auto bn_param = reinterpret_cast(opParameter); - int batch = bn_param->batch; - int channels = bn_param->channels; - int spatial = bn_param->spatial; - float eps = bn_param->eps; +int BNGradCPUKernel::Run() { + // std::cout << "run succ" << std::endl; + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } + auto bn_param = reinterpret_cast(op_parameter_); + auto *input_yt = in_tensors_.at(0); + auto *input_x = in_tensors_.at(1); + auto *input_scale = in_tensors_.at(2); + auto *output_dx = out_tensors_.at(0); + auto *output_scale = out_tensors_.at(1); + auto *output_bias = out_tensors_.at(2); + // Tensor *bias = input[5]; + int batch = input_x->Batch(); + int channels = input_x->Channel(); + int spatial = input_x->Height() * input_x->Width(); + float eps = bn_param->epsilon_; std::fill(workspace, workspace + workspace_size, 0.f); - float *mean = workspace; - float *variance = mean + channels; - float *mean_delta = variance + channels; + float *invar = mean + channels; + float *mean_delta = invar + channels; float *variance_delta = mean_delta + channels; float *mean_add_delta = variance_delta + channels; float *x = reinterpret_cast(input_x->Data()); float *yt = reinterpret_cast(input_yt->Data()); float *scale = reinterpret_cast(input_scale->Data()); - float *out = reinterpret_cast(output_grad->Data()); + float *dx = reinterpret_cast(output_dx->Data()); + float *dscale = reinterpret_cast(output_scale->Data()); + float *dbias = reinterpret_cast(output_bias->Data()); - std::copy(yt, yt + batch * channels * spatial, out); - meanVar(x, batch, spatial, channels, mean, variance); - scaleBias(scale, batch, channels, spatial, out); - meanDelta(out, spatial, channels, eps, variance, mean_delta); - varianceDelta(x, out, mean, variance, batch, channels, spatial, eps, variance_delta); + std::copy(yt, yt + batch * channels * spatial, dx); + meanVar(x, batch, spatial, channels, eps, mean, invar); + scaleBias(scale, batch, channels, spatial, dx); + meanDelta(dx, spatial, channels, invar, mean_delta); + varianceDelta(x, dx, mean, invar, batch, channels, spatial, variance_delta); meanAdd(x, mean, variance_delta, batch, channels, spatial, mean_add_delta, mean_delta); - NormalizeDelta(x, mean, variance, mean_delta, variance_delta, batch, channels, eps, spatial, out); + NormalizeDelta(x, mean, invar, mean_delta, variance_delta, batch, channels, spatial, dx); + // dbias + sumSpatialBatch(yt, batch * spatial, channels, dbias); + // dscale + backwardScale(x, mean, invar, yt, batch, channels, spatial, dscale); return RET_OK; } -kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc, - const mindspore::lite::PrimitiveC *primitive) { +kernel::LiteKernel *CpuBNGradFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_BNGradInput); - auto *kernel = new (std::nothrow) BNGradInputCPUKernel(opParameter, inputs, outputs, ctx, primitive); + MS_ASSERT(desc.type == schema::PrimitiveType_BNGrad); + auto *kernel = new (std::nothrow) BNGradCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { - MS_LOG(ERROR) << "new BNGradInputCPUKernel fail!"; + MS_LOG(ERROR) << "new BNGradCPUKernel fail!"; return nullptr; } auto ret = kernel->Init(); @@ -112,5 +136,5 @@ kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector #include "src/lite_kernel.h" #include "ir/anf.h" + namespace mindspore::kernel { -class BNGradInputCPUKernel : public LiteKernel { + + + +class BNGradCPUKernel : public LiteKernel { public: - explicit BNGradInputCPUKernel(OpParameter *parameter, const std::vector &inputs, + explicit BNGradCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~BNGradInputCPUKernel() override { delete workspace; } + ~BNGradCPUKernel() override { delete workspace; } int Init() override; int ReSize() override; @@ -38,5 +42,8 @@ class BNGradInputCPUKernel : public LiteKernel { float *workspace; int workspace_size; }; + +// OpParameter *PopulateBNGradParameter(const lite::Primitive *primitive); + } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BNGRAD_INPUT_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc new file mode 100644 index 0000000000..30c4294ad7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32_grad/convolution.h" +#include "nnacl/fp32_grad/pack_ext.h" +#include "nnacl/fp32_grad/gemm.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int ConvolutionTrainCPUKernel::Init() { + auto conv_param_ = reinterpret_cast(op_parameter_); + auto *input_x = in_tensors_.at(kInputIndex); + auto *input_weight = in_tensors_.at(kWeightIndex); + auto *out_y = out_tensors_.at(kOutputIndex); + + conv_param_->output_batch_ = out_y->shape().at(kNHWC_N); + conv_param_->input_batch_ = input_x->shape().at(kNHWC_N); + conv_param_->input_h_ = input_x->shape().at(kNHWC_H); + conv_param_->input_w_ = input_x->shape().at(kNHWC_W); + conv_param_->output_h_ = out_y->shape().at(kNHWC_H); + conv_param_->output_w_ = out_y->shape().at(kNHWC_W); + conv_param_->input_channel_ = input_x->shape().at(kNHWC_C); + conv_param_->output_channel_ = input_weight->shape().at(kNHWC_N); + conv_param_->kernel_h_ = input_weight->shape().at(kNHWC_H); + conv_param_->kernel_w_ = input_weight->shape().at(kNHWC_W); + + int ws_size = conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->kernel_h_ * conv_param_->kernel_w_ * + conv_param_->input_channel_ / conv_param_->group_; + + workspace = new float[ws_size]; + return RET_OK; +} + +int ConvolutionTrainCPUKernel::ReSize() { return RET_OK; } + +int ConvolutionTrainCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } + auto conv_param_ = reinterpret_cast(op_parameter_); + auto *input_x = in_tensors_.at(kInputIndex); + auto *input_w = in_tensors_.at(kWeightIndex); + auto *out_y = out_tensors_.at(kOutputIndex); + + auto x_addr = reinterpret_cast(input_x->Data()); + auto y_addr = reinterpret_cast(out_y->Data()); + auto w_addr = reinterpret_cast(input_w->Data()); + + int i, j; + int nweights = input_w->ElementsNum(); + int in_ch = conv_param_->input_channel_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int k_h = conv_param_->kernel_h_; + int k_w = conv_param_->kernel_w_; + int batch = conv_param_->output_batch_; + int out_ch = conv_param_->output_channel_; // out_y->shape()[3]; + int groups = conv_param_->group_; + int out_h = conv_param_->output_h_; + int out_w = conv_param_->output_w_; + int m = out_h * out_w; + int n = out_ch / groups; + int k = k_h * k_w * in_ch / groups; + + memset(y_addr, 0, out_y->Size()); + + for (i = 0; i < batch; ++i) { + for (j = 0; j < groups; ++j) { + float *mat_a = workspace; + float *mat_b = w_addr + j * nweights / groups; + float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups); + float *im = x_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups); + im2col_hwc(im, mat_a, conv_param_); + gemm(0, 1, m, n, k, 1, mat_a, k, mat_b, k, 1, mat_c, out_ch); + } + } + + // std::cout << "run succ" << std::endl; + return RET_OK; +} + +kernel::LiteKernel *CpuConvTrainFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + + auto *kernel = new (std::nothrow) ConvolutionTrainCPUKernel(opParameter, inputs, outputs, ctx, primitive); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h new file mode 100644 index 0000000000..5a44c11e3a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +namespace mindspore::kernel { +class ConvolutionTrainCPUKernel : public LiteKernel { + public: + explicit ConvolutionTrainCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx, + const lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~ConvolutionTrainCPUKernel() override { delete [] workspace; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + float *workspace; +}; + +kernel::LiteKernel *CpuConvTrainFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::PrimitiveC *primitive); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc index cd69162d52..4e5260e8ff 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc @@ -33,30 +33,24 @@ int ConvolutionGradFilterCPUKernel::Init() { // x is in input 1 // dw is output 0 - if (2 != this->inputs_.size()) { - MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs"; - return RET_ERROR; - } - if (1 != this->outputs_.size()) { - MS_LOG(ERROR) << "Conv2d Grad should has one output"; - return RET_ERROR; - } - - auto *input_tensor = inputs_.at(1); - MS_ASSERT(input_tensor != nullptr); - auto *dy = inputs_.at(0); - MS_ASSERT(dy != nullptr); - auto *weight_tensor = outputs_.at(0); + auto *x_tensor = in_tensors_.at(1); + MS_ASSERT(x_tensor != nullptr); + auto *dy_tensor = in_tensors_.at(0); + MS_ASSERT(dy_tensor != nullptr); + auto *weight_tensor = out_tensors_.at(0); MS_ASSERT(weight_tensor != nullptr); - auto conv_param = reinterpret_cast(opParameter); - conv_param->output_batch_ = this->inputs_.at(0)->shape().at(kNHWC_N); - conv_param->input_batch_ = this->inputs_.at(1)->shape().at(kNHWC_N); - conv_param->input_h_ = this->inputs_.at(1)->shape().at(kNHWC_H); - conv_param->input_w_ = this->inputs_.at(1)->shape().at(kNHWC_W); - // assume OutCh|kh|kw|In - conv_param->input_channel_ = this->inputs_.at(1)->shape().at(kNHWC_C); - conv_param->output_channel_ = this->outputs_.at(0)->shape().at(kNHWC_N); + auto conv_param = reinterpret_cast(op_parameter_); + conv_param->output_batch_ = dy_tensor->shape().at(kNHWC_N); + conv_param->input_batch_ = x_tensor->shape().at(kNHWC_N); + conv_param->input_h_ = x_tensor->shape().at(kNHWC_H); + conv_param->input_w_ = x_tensor->shape().at(kNHWC_W); + // assume OutCh|kh|kw|InCh + conv_param->input_channel_ = x_tensor->shape().at(kNHWC_C); + conv_param->output_channel_ = dy_tensor->shape().at(kNHWC_C); + // TBD + conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; + conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; int ws_size = conv_param->output_h_ * conv_param->output_w_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; @@ -67,34 +61,21 @@ int ConvolutionGradFilterCPUKernel::Init() { return RET_ERROR; } - int output_w = 0; - int output_h = 0; - output_h = dy->shape()[kNHWC_H]; - output_w = dy->shape()[kNHWC_W]; - - std::vector out_shape(4); - out_shape.at(0) = conv_param->output_channel_; - out_shape.at(1) = conv_param->kernel_h_; - out_shape.at(2) = conv_param->kernel_w_; - out_shape.at(3) = conv_param->input_channel_ / conv_param->group_; - - // weight is output - weight_tensor->set_shape(out_shape); - weight_tensor->set_data_type(input_tensor->data_type()); - - conv_param->output_h_ = output_h; - conv_param->output_w_ = output_w; - return RET_OK; } -int ConvolutionGradFilterCPUKernel::ReSize() { return 0; } +int ConvolutionGradFilterCPUKernel::ReSize() { return RET_OK; } int ConvolutionGradFilterCPUKernel::Run() { - auto conv_param = reinterpret_cast(opParameter); - auto *input_dy = inputs_.at(0); - auto *input_x = inputs_.at(1); - auto *out_dw = outputs_.at(0); + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } + auto conv_param = reinterpret_cast(op_parameter_); + auto *input_dy = in_tensors_.at(0); + auto *input_x = in_tensors_.at(1); + auto *out_dw = out_tensors_.at(0); auto x_addr = reinterpret_cast(input_x->Data()); auto dy_addr = reinterpret_cast(input_dy->Data()); @@ -135,7 +116,48 @@ int ConvolutionGradFilterCPUKernel::Run() { // std::cout << "run succ" << std::endl; return RET_OK; } +#if 0 +OpParameter *PopulateConvolutionGradFilterParameter(const lite::Primitive *primitive) { + ConvParameter *param = new (std::nothrow) ConvParameter(); + if (param == nullptr) { + MS_LOG(ERROR) << "new Param for conv grad filter failed."; + return nullptr; + } + param->op_parameter_.type_ = primitive->Type(); + + auto convg_primitive = primitive->Value()->value_as_Conv2DGradFilter(); + param->kernel_h_ = convg_primitive->kernelH(); + param->kernel_w_ = convg_primitive->kernelW(); + param->stride_h_ = convg_primitive->strideH(); + param->stride_w_ = convg_primitive->strideW(); + param->dilation_h_ = convg_primitive->dilateH(); + param->dilation_w_ = convg_primitive->dilateW(); + param->pad_h_ = convg_primitive->padUp(); + param->pad_w_ = convg_primitive->padLeft(); + param->pad_u_ = convg_primitive->padUp(); + param->pad_d_ = convg_primitive->padDown(); + param->pad_l_ = convg_primitive->padLeft(); + param->pad_r_ = convg_primitive->padRight(); + param->group_ = convg_primitive->group(); + auto act_type = convg_primitive->activationType(); + switch (act_type) { + case schema::ActivationType_RELU: + param->is_relu_ = true; + param->is_relu6_ = false; + break; + case schema::ActivationType_RELU6: + param->is_relu_ = false; + param->is_relu6_ = true; + break; + default: + param->is_relu_ = false; + param->is_relu6_ = false; + break; + } + return reinterpret_cast(param); +} +#endif kernel::LiteKernel *CpuConvGradFilterFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h index 75a345b6fb..efc1cb6604 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h @@ -1,4 +1,4 @@ -/** + /** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,15 +28,17 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel { const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~ConvolutionGradFilterCPUKernel() override { delete workspace; } + ~ConvolutionGradFilterCPUKernel() override { delete [] workspace; } int Init() override; int ReSize() override; int Run() override; private: - float *workspace; + float *workspace = nullptr; }; + + } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc index c067053c41..0d2a4faf25 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc @@ -29,23 +29,14 @@ using mindspore::schema::PrimitiveType_Conv2DGradInput; namespace mindspore::kernel { int ConvolutionGradInputCPUKernel::Init() { - if (2 != this->inputs_.size()) { - MS_LOG(ERROR) << "Conv2d Grad should has 2 inputs"; - return RET_ERROR; - } - if (1 != this->outputs_.size()) { - MS_LOG(ERROR) << "Conv2d Grad should has one output"; - return RET_ERROR; - } - - auto *dy_tensor = inputs_.at(kInputIndex); + auto *dy_tensor = in_tensors_.at(kInputIndex); MS_ASSERT(dy_tensor != nullptr); - auto *weight_tensor = inputs_.at(kWeightIndex); + auto *weight_tensor = in_tensors_.at(kWeightIndex); MS_ASSERT(weight_tensor != nullptr); - auto *dx_tensor = outputs_.at(kOutputIndex); + auto *dx_tensor = out_tensors_.at(kOutputIndex); MS_ASSERT(dx_tensor != nullptr); - auto conv_param = reinterpret_cast(opParameter); + auto conv_param = reinterpret_cast(op_parameter_); conv_param->output_batch_ = dx_tensor->shape()[(kNHWC_N)]; conv_param->input_batch_ = dy_tensor->shape()[(kNHWC_N)]; @@ -74,10 +65,16 @@ int ConvolutionGradInputCPUKernel::Init() { int ConvolutionGradInputCPUKernel::ReSize() { return 0; } int ConvolutionGradInputCPUKernel::Run() { - auto conv_param = reinterpret_cast(opParameter); - auto *input_dy = inputs_.at(0); - auto *input_w = inputs_.at(1); - auto *out_dx = outputs_.at(0); + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } + + auto conv_param = reinterpret_cast(op_parameter_); + auto *input_dy = in_tensors_.at(0); + auto *input_w = in_tensors_.at(1); + auto *out_dx = out_tensors_.at(0); auto dy_addr = reinterpret_cast(input_dy->Data()); auto w_addr = reinterpret_cast(input_w->Data()); @@ -116,6 +113,49 @@ int ConvolutionGradInputCPUKernel::Run() { return 0; } +#if 0 +OpParameter *PopulateConvolutionGradInputParameter(const lite::Primitive *primitive) { + ConvParameter *param = new (std::nothrow) ConvParameter(); + if (param == nullptr) { + MS_LOG(ERROR) << "new Param for conv grad input failed."; + return nullptr; + } + param->op_parameter_.type_ = primitive->Type(); + + auto convg_primitive = primitive->Value()->value_as_Conv2DGradInput(); + param->kernel_h_ = convg_primitive->kernelH(); + param->kernel_w_ = convg_primitive->kernelW(); + param->stride_h_ = convg_primitive->strideH(); + param->stride_w_ = convg_primitive->strideW(); + param->dilation_h_ = convg_primitive->dilateH(); + param->dilation_w_ = convg_primitive->dilateW(); + param->pad_h_ = convg_primitive->padUp(); + param->pad_w_ = convg_primitive->padLeft(); + param->pad_u_ = convg_primitive->padUp(); + param->pad_d_ = convg_primitive->padDown(); + param->pad_l_ = convg_primitive->padLeft(); + param->pad_r_ = convg_primitive->padRight(); + param->group_ = convg_primitive->group(); + auto act_type = convg_primitive->activationType(); + switch (act_type) { + case schema::ActivationType_RELU: + param->is_relu_ = true; + param->is_relu6_ = false; + break; + case schema::ActivationType_RELU6: + param->is_relu_ = false; + param->is_relu6_ = true; + break; + default: + param->is_relu_ = false; + param->is_relu6_ = false; + break; + } + + return reinterpret_cast(param); +} +#endif + kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h index 6bda66b3dd..93930fded8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h @@ -28,7 +28,7 @@ class ConvolutionGradInputCPUKernel : public LiteKernel { const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~ConvolutionGradInputCPUKernel() override { delete workspace; } + ~ConvolutionGradInputCPUKernel() override { delete [] workspace; } int Init() override; int ReSize() override; @@ -37,6 +37,9 @@ class ConvolutionGradInputCPUKernel : public LiteKernel { private: float *workspace; }; + +// OpParameter *PopulateConvolutionGradInputParameter(const lite::Primitive *primitive); + } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/depend.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/depend.cc new file mode 100644 index 0000000000..a9bc417462 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/depend.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/fp32_grad/depend.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Depend; + +namespace mindspore::kernel { + +int DependCPUKernel::Init() { + return RET_OK; +} + +int DependCPUKernel::ReSize() { return 0; } + +int DependCPUKernel::Run() { +#if 0 + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } + auto in = reinterpret_cast(in_tensors_.at(0)->Data()); + auto out = reinterpret_cast(out_tensors_.at(0)->Data()); + + memcpy(out, in, in_tensors_.at(0)->Size()); +#endif + return RET_OK; +} + +kernel::LiteKernel *CpuDependFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Depend); + auto *kernel = + new (std::nothrow) DependCPUKernel(opParameter, inputs, outputs, ctx, primitive); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Depend, CpuDependFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/depend.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/depend.h new file mode 100644 index 0000000000..2b222ecbaf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/depend.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DEPEND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DEPEND_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +#include "nnacl/fp32/arithmetic.h" + +namespace mindspore::kernel { +class DependCPUKernel : public LiteKernel { + public: + explicit DependCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx, + const lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param = parameter; + } + ~DependCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + OpParameter *param; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_DEPEND_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/make_tuple.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/make_tuple.h new file mode 100644 index 0000000000..bbbc22f902 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/make_tuple.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_MAKE_TUPLE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_MAKE_TUPLE_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +#include "src/runtime/kernel/arm/nnacl/fp32/arithmetic.h" + +namespace mindspore::kernel { +class MakeTupleCPUKernel : public LiteKernel { + public: + explicit MakeTupleCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param = parameter; + } + ~MakeTupleCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + OpParameter *param; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_MAKE_TUPLE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.cc deleted file mode 100644 index 4d332b8f84..0000000000 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/opt_momentum.cc +++ /dev/null @@ -1,87 +0,0 @@ - -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "schema/model_generated.h" -#include "src/kernel_registry.h" -#include "src/runtime/kernel/arm/fp32_grad/opt_momentum.h" -#include "include/errorcode.h" - -using mindspore::kernel::KERNEL_ARCH::kCPU; -using mindspore::lite::KernelRegistrar; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -using mindspore::schema::PrimitiveType_OptMomentum; - -namespace mindspore::kernel { - -int OptMomentumCPUKernel::ReSize() { return 0; } - -int OptMomentumCPUKernel::Run() { - auto prepare_ret = Prepare(); - if (prepare_ret != RET_OK) { - MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; - return prepare_ret; - } - if (inputs_.size() != 5 || !outputs_.empty()) { - MS_LOG(ERROR) << "OptMomentumCPUKernel error input output size!"; - return RET_ERROR; - } - - if (inputs_[0]->ElementsNum() != inputs_[1]->ElementsNum() || - inputs_[0]->ElementsNum() != inputs_[3]->ElementsNum()) { - MS_LOG(ERROR) << "error input data size!"; - return RET_ERROR; - } - auto weight = reinterpret_cast(inputs_[0]->Data()); - auto accumulate = reinterpret_cast(inputs_[1]->Data()); - float learning_rate = reinterpret_cast(inputs_[2]->Data())[0]; - auto gradient = reinterpret_cast(inputs_[3]->Data()); - float moment = reinterpret_cast(inputs_[4]->Data())[0]; - size_t elem_num = inputs_[0]->ElementsNum(); - for (size_t i = 0; i < elem_num; ++i) { - accumulate[i] = accumulate[i] * moment + gradient[i]; - weight[i] -= accumulate[i] * learning_rate; - } - return RET_OK; -} - -int OptMomentumCPUKernel::Init() { return 0; } - -kernel::LiteKernel *CpuOptMomentumFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc, - const mindspore::lite::PrimitiveC *primitive) { - MS_ASSERT(desc.type == schema::PrimitiveType_OptMomentum); - auto *kernel = new (std::nothrow) OptMomentumCPUKernel(opParameter, inputs, outputs, ctx, primitive); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new OptMomentumCPUKernel fail!"; - return nullptr; - } - - auto ret = kernel->Init(); - if (0 != ret) { - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - delete kernel; - return nullptr; - } - return kernel; -} - -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_OptMomentum, CpuOptMomentumFp32KernelCreator) -} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc index 3082b1bc43..9c6c4b90fd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc @@ -20,6 +20,7 @@ #include "nnacl/fp32/pooling.h" #include "nnacl/fp32_grad/pooling_grad.h" #include "include/errorcode.h" +// #include "src/train/ops/train_ops.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -29,9 +30,15 @@ using mindspore::schema::PrimitiveType_PoolingGrad; namespace mindspore::kernel { int PoolingGradCPUKernel::Init() { - PoolingParameter *pool_param = reinterpret_cast(opParameter); + PoolingParameter *pool_param = reinterpret_cast(op_parameter_); - auto in_shape = inputs_.at(0)->shape(); + auto in_shape = in_tensors_.at(0)->shape(); + auto out_shape = in_tensors_.at(1)->shape(); + + if (pool_param->pool_mode_ == PoolMode_AvgPool) { + in_shape = in_tensors_.at(1)->shape(); + out_shape = in_tensors_.at(0)->shape(); + } int input_h = in_shape.at(1); int input_w = in_shape.at(2); @@ -40,25 +47,39 @@ int PoolingGradCPUKernel::Init() { pool_param->window_h_ = input_h; } + pool_param->input_h_ = in_shape[kNHWC_H]; + pool_param->input_w_ = in_shape[kNHWC_W]; + pool_param->input_batch_ = in_shape[kNHWC_N]; + pool_param->input_channel_ = in_shape[kNHWC_C]; + // Emir -- here I assume we get the outputshape in the output tensor - auto *out_tensor = outputs_.front(); - auto out_shape = out_tensor->shape(); + // auto *out_tensor = out_tensors_.front(); + // auto out_shape = in_tensors_.at(1)->shape(); + + pool_param->output_h_ = out_shape[kNHWC_H]; + pool_param->output_w_ = out_shape[kNHWC_W]; + pool_param->output_batch_ = out_shape[kNHWC_N]; + pool_param->output_channel_ = out_shape[kNHWC_C]; - out_tensor->set_shape(out_shape); - out_tensor->set_data_type(inputs_.at(0)->data_type()); return RET_OK; } int PoolingGradCPUKernel::ReSize() { return RET_OK; } int PoolingGradCPUKernel::Run() { - PoolingParameter *pool_param = reinterpret_cast(opParameter); - auto input_ptr = reinterpret_cast(inputs_.at(0)->Data()); - auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } + PoolingParameter *pool_param = reinterpret_cast(op_parameter_); + auto input_ptr = reinterpret_cast(in_tensors_.at(0)->Data()); + auto output_ptr = reinterpret_cast(out_tensors_.at(0)->Data()); if (pool_param->pool_mode_ == PoolMode_MaxPool) { - auto ind = reinterpret_cast(inputs_.at(1)->Data()); - MaxPoolingGrad(input_ptr, ind, output_ptr, pool_param); + auto dx_ptr = reinterpret_cast(in_tensors_.at(1)->Data()); + auto dy_ptr = reinterpret_cast(in_tensors_.at(2)->Data()); + MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param); } else { AvgPoolingGrad(input_ptr, output_ptr, pool_param); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h index 980aa5f8b9..7b26658dce 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.h @@ -43,6 +43,7 @@ class PoolingGradCPUKernel : public LiteKernel { private: uint8_t data_shape_{0}; }; + } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POOLING_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc index 5127acb8cb..0b5ba98281 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc @@ -31,10 +31,10 @@ int PowerGradCPUKernel::Init() { return RET_OK; } int PowerGradCPUKernel::ReSize() { return RET_OK; } int PowerGradCPUKernel::Run() { - auto dy_addr = reinterpret_cast(inputs_.at(0)->Data()); - auto x_addr = reinterpret_cast(inputs_.at(1)->Data()); - auto dx_addr = reinterpret_cast(outputs_.at(0)->Data()); - auto size = inputs_.at(0)->ElementsNum(); + auto dy_addr = reinterpret_cast(in_tensors_.at(0)->Data()); + auto x_addr = reinterpret_cast(in_tensors_.at(1)->Data()); + auto dx_addr = reinterpret_cast(out_tensors_.at(0)->Data()); + auto size = in_tensors_.at(0)->ElementsNum(); float exp = power_ - 1; Power(x_addr, &exp, dx_addr, size, scale_, shift_, true); @@ -47,6 +47,7 @@ int PowerGradCPUKernel::Run() { return RET_OK; } + kernel::LiteKernel *CpuPowerGradFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h index 60f406d04f..5d4104d5e3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h @@ -45,6 +45,7 @@ class PowerGradCPUKernel : public LiteKernel { float scale_; float shift_; }; + } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_POWER_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc index 315171efcc..2b71ad2e70 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include "src/kernel_registry.h" #include "nnacl/softmax_parameter.h" #include "nnacl/fp32/softmax.h" @@ -46,9 +47,10 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int output[0] = total_loss / param->batch_size_; } -void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, float *grads, float *output) const { size_t row_start = 0; + float total_loss = 0; for (int i = 0; i < param->batch_size_; ++i) { if (labels[i] < 0) { MS_LOG(EXCEPTION) << "label value must >= 0"; @@ -56,78 +58,88 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *la size_t label = labels[i]; if (label > param->number_of_classes_) { MS_LOG(EXCEPTION) << "error label input!"; - } - for (size_t j = 0; j < param->number_of_classes_; ++j) { - size_t index = row_start + j; - if (j == label) { - output[index] = (losses[index] - 1) / param->batch_size_; - } else { - output[index] = losses[index] / param->batch_size_; + } else { + total_loss -= logf(losses[i * param->number_of_classes_ + label]); + for (size_t j = 0; j < param->number_of_classes_; ++j) { + size_t index = row_start + j; + if (j == label) { + grads[index] = (losses[index] - 1) / param->batch_size_; + } else { + grads[index] = losses[index] / param->batch_size_; + } } } row_start += param->number_of_classes_; } + output[0] = total_loss / param->batch_size_; } int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { - auto ins = reinterpret_cast(inputs_.at(0)->Data()); - auto labels = reinterpret_cast(inputs_.at(1)->Data()); - auto out = reinterpret_cast(outputs_.at(1)->Data()); + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return ret; + } + + auto ins = reinterpret_cast(in_tensors_.at(0)->Data()); + auto labels = reinterpret_cast(in_tensors_.at(1)->Data()); + float *out = reinterpret_cast(out_tensors_.at(0)->Data()); float *grads = NULL; - if (is_train()) { // outputs_.size() > 1) - grads = reinterpret_cast(outputs_.at(0)->Data()); + if (is_train() && out_tensors_.size() > 1) { + grads = reinterpret_cast(out_tensors_.at(1)->Data()); } - size_t data_size = inputs_.at(0)->ElementsNum(); + size_t data_size = in_tensors_.at(0)->ElementsNum(); float *losses = new (std::nothrow) float[data_size]; if (losses == nullptr) { MS_LOG(ERROR) << "losses is null"; - return nullptr; + return RET_ERROR; } - std::fill(losses, losses + data_size, 0); - MS_ASSERT(out != nullptr); MS_ASSERT(labels != nullptr); MS_ASSERT(ins != nullptr); - - SoftmaxParameter sm_params; - sm_params.n_dim_ = param->n_dim_; - sm_params.element_size_ = data_size; - sm_params.axis_ = 0; - for (int i = 0; i < 4; i++) // softmax has only 4 params in shape - sm_params.input_shape_[i] = param->input_shape_[i]; - float sum_data[sm_params.input_shape_[sm_params.axis_]] = {0}; - std::fill(sum_data, sum_data + sm_params.input_shape_[sm_params.axis_], 0); - Softmax(ins, losses, sum_data, &sm_params); - + std::fill(losses_, losses_ + data_size, 0); + std::fill(sum_data_, sum_data_ + sm_params_.input_shape_[0], 0); + Softmax(ins, losses_, sum_data_, &sm_params_); if (is_train()) { - GradPostExecute(labels, losses, grads); - } else { - ForwardPostExecute(labels, losses, out); + GradPostExecute(labels, losses_, grads, out); + } else if (out != nullptr) { + ForwardPostExecute(labels, losses_, out); } return RET_OK; } int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() { - if (context_->infer_shape_interrupt_ && !context_->running_) { - SetNeedReInit(); - return RET_OK; - } - auto dims = inputs_[0]->shape(); + // if (context_ && context_->infer_shape_interrupt_ && !context_->running_) { + // set_need_reinit(); + // return RET_OK; + // } + auto dims = in_tensors_[0]->shape(); param->n_dim_ = 2; param->number_of_classes_ = dims[1]; param->batch_size_ = dims[0]; for (unsigned int i = 0; i < dims.size(); i++) param->input_shape_[i] = dims[i]; - if (2 != this->inputs_.size()) { + if (2 != this->in_tensors_.size()) { MS_LOG(ERROR) << "softmax entropy loss should have two inputs"; return RET_ERROR; } - auto *in0 = inputs_.front(); + auto *in0 = in_tensors_.front(); if (in0 == nullptr) { MS_LOG(ERROR) << "softmax etropy loss in0 have no data"; return RET_ERROR; } + size_t data_size = in_tensors_.at(0)->ElementsNum(); + losses_ = new (std::nothrow) float[data_size]; + sum_data_ = new (std::nothrow) float[dims[0]]; + MS_ASSERT(losses_ != nullptr); + MS_ASSERT(sum_data_ != nullptr); + + sm_params_.n_dim_ = 2; + sm_params_.element_size_ = data_size; + sm_params_.axis_ = 1; + for (int i = 0; i < dims.size(); i++) sm_params_.input_shape_[i] = dims[i]; + return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h index 1e243ef542..d961f06960 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h @@ -14,31 +14,32 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ #include -#include "src/lite_kernel.h" +#include "src/train/loss_kernel.h" #include "ir/anf.h" #include "nnacl/fp32_grad/softmax_grad.h" #include "nnacl/fp32/arithmetic.h" +#include "nnacl/softmax_parameter.h" namespace mindspore::kernel { -class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel { +class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { public: explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + : LossKernel(parameter, inputs, outputs, ctx, primitive) { param = reinterpret_cast(parameter); } - ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override = default; + ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override { delete[] losses_; delete[] sum_data_; } void ForwardPostExecute(const int *labels, const float *losses, float *output) const; - void GradPostExecute(const int *labels, const float *losses, float *output) const; + void GradPostExecute(const int *labels, const float *losses, float* grads, float *output) const; int Init() override; int ReSize() override; @@ -46,7 +47,11 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel { private: SoftmaxCrossEntropyParameter *param; + SoftmaxParameter sm_params_; + float *losses_ = nullptr; + float *sum_data_ = nullptr; }; + } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.cc new file mode 100644 index 0000000000..34b78f31b8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/fp32_grad/tuple_getitem.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_TupleGetItem; + +namespace mindspore::kernel { + +int TupleGetItemCPUKernel::Init() { + return RET_OK; +} + +int TupleGetItemCPUKernel::ReSize() { return 0; } + +int TupleGetItemCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } + auto in = reinterpret_cast(in_tensors_.at(0)->Data()); + auto out = reinterpret_cast(out_tensors_.at(0)->Data()); + + memcpy(out, in, in_tensors_.at(0)->Size()); + + return RET_OK; +} + +kernel::LiteKernel *CpuTupleGetItemFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_TupleGetItem); + auto *kernel = + new (std::nothrow) TupleGetItemCPUKernel(opParameter, inputs, outputs, ctx, primitive); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TupleGetItem, CpuTupleGetItemFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.h new file mode 100644 index 0000000000..27100ecaaf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_TUPLE_GETITEM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_TUPLE_GETITEM_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +#include "nnacl/fp32/arithmetic.h" + +namespace mindspore::kernel { +class TupleGetItemCPUKernel : public LiteKernel { + public: + explicit TupleGetItemCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx, + const lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param = parameter; + } + ~TupleGetItemCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + OpParameter *param; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_TUPLE_GETITEM_H_ diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 925dd52c72..d7d6098ff1 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -94,8 +94,10 @@ int Scheduler::InferShape(const lite::Model *model, std::vectorat(size_t(inIndexes->GetAs(j)))); } auto outIndexes = cNode->outputIndex(); - for (size_t j = 0; j < outIndexes->size(); j++) { - outputs.emplace_back(tensors->at(size_t(outIndexes->GetAs(j)))); + if (outIndexes != nullptr) { + for (size_t j = 0; j < outIndexes->size(); j++) { + outputs.emplace_back(tensors->at(size_t(outIndexes->GetAs(j)))); + } } auto *primitive = model->GetOp(cNode->name()->str()); if (primitive == nullptr) { diff --git a/mindspore/lite/src/train/loss_kernel.h b/mindspore/lite/src/train/loss_kernel.h new file mode 100644 index 0000000000..70ac705771 --- /dev/null +++ b/mindspore/lite/src/train/loss_kernel.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_TRAIN_LOSS_KERNEL_H_ +#define MINDSPORE_LITE_SRC_TRAIN_LOSS_KERNEL_H_ +#include +#include "src/lite_kernel.h" +namespace mindspore::kernel { + +class LossKernel : public LiteKernel { + public: + LossKernel() = default; + explicit LossKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, + const lite::Context *ctx, + const lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~LossKernel() = default; +}; + +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_TRAIN_LOSS_KERNEL_H_ diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc new file mode 100644 index 0000000000..b8725f4fe2 --- /dev/null +++ b/mindspore/lite/src/train/train_populate_parameter.cc @@ -0,0 +1,250 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/populate_parameter.h" +#include "src/train/train_populate_parameter.h" +#include "src/ops/pooling_grad.h" +#include "nnacl/pooling_parameter.h" +#include "src/ops/softmax_cross_entropy.h" +#include "nnacl/fp32_grad/softmax_grad.h" +#include "src/ops/activation_grad.h" +#include "nnacl/fp32/activation.h" +#include "src/ops/conv2d_grad_filter.h" +#include "src/ops/conv2d_grad_input.h" +#include "nnacl/conv_parameter.h" +#include "src/ops/power_grad.h" +#include "nnacl/power_parameter.h" + +namespace mindspore::kernel { + +OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + + OpParameter *param = new (std::nothrow) OpParameter(); + if (param == nullptr) { + MS_LOG(ERROR) << "new Param for primitive failed."; + return nullptr; + } + + param->type_ = primitive->Type(); + return param; +} + +OpParameter *PopulateSoftmaxCrossEntropyParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + SoftmaxCrossEntropyParameter *sce_param = new (std::nothrow) SoftmaxCrossEntropyParameter(); + if (sce_param == nullptr) { + MS_LOG(ERROR) << "new SoftmaxCrossEntropyParameter failed."; + return nullptr; + } + sce_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(sce_param); +} + +OpParameter *PopulatePoolingGradParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + PoolingParameter *pooling_param = new (std::nothrow) PoolingParameter(); + if (pooling_param == nullptr) { + MS_LOG(ERROR) << "new PoolingParameter failed."; + return nullptr; + } + pooling_param->op_parameter_.type_ = primitive->Type(); + auto pooling_primitive = + reinterpret_cast(const_cast(primitive)); + + pooling_param->global_ = pooling_primitive->GetGlobal(); + pooling_param->window_w_ = pooling_primitive->GetWindowW(); + pooling_param->window_h_ = pooling_primitive->GetWindowH(); + + pooling_param->pad_u_ = pooling_primitive->GetPadUp(); + pooling_param->pad_d_ = pooling_primitive->GetPadDown(); + pooling_param->pad_l_ = pooling_primitive->GetPadLeft(); + pooling_param->pad_r_ = pooling_primitive->GetPadRight(); + pooling_param->stride_w_ = pooling_primitive->GetStrideW(); + pooling_param->stride_h_ = pooling_primitive->GetStrideH(); + + pooling_param->pool_mode_ = PoolMode_No; + pooling_param->round_mode_ = RoundMode_No; + + switch (pooling_primitive->GetPoolingMode()) { + case schema::PoolMode_MAX_POOLING: + pooling_param->pool_mode_ = PoolMode_MaxPool; + break; + case schema::PoolMode_MEAN_POOLING: + pooling_param->pool_mode_ = PoolMode_AvgPool; + break; + default: + break; + } + + switch (pooling_primitive->GetRoundMode()) { + case schema::RoundMode_FLOOR: + pooling_param->round_mode_ = RoundMode_Floor; + break; + case schema::RoundMode_CEIL: + pooling_param->round_mode_ = RoundMode_Ceil; + break; + default: + break; + } + return reinterpret_cast(pooling_param); +} + +OpParameter *PopulateActivationGradParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + + ActivationParameter *act_param = new (std::nothrow) ActivationParameter(); + if (act_param == nullptr) { + MS_LOG(ERROR) << "new ActivationParameter failed."; + return nullptr; + } + act_param->op_parameter_.type_ = primitive->Type(); + auto activation = + reinterpret_cast(const_cast(primitive)); + act_param->type_ = static_cast(activation->GetType()); + act_param->alpha_ = activation->GetAlpha(); + return reinterpret_cast(act_param); +} + +OpParameter *PopulateConvolutionGradFilterParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + + ConvParameter *param = new (std::nothrow) ConvParameter(); + if (param == nullptr) { + MS_LOG(ERROR) << "new Param for conv grad filter failed."; + return nullptr; + } + param->op_parameter_.type_ = primitive->Type(); + + auto convg_primitive = + reinterpret_cast(const_cast(primitive)); + param->kernel_h_ = convg_primitive->GetKernelH(); + param->kernel_w_ = convg_primitive->GetKernelW(); + param->stride_h_ = convg_primitive->GetStrideH(); + param->stride_w_ = convg_primitive->GetStrideW(); + param->dilation_h_ = convg_primitive->GetDilateH(); + param->dilation_w_ = convg_primitive->GetDilateW(); + param->pad_u_ = convg_primitive->GetPadUp(); + param->pad_d_ = convg_primitive->GetPadDown(); + param->pad_l_ = convg_primitive->GetPadLeft(); + param->pad_r_ = convg_primitive->GetPadRight(); + param->group_ = convg_primitive->GetGroup(); + param->act_type_ = ActType_No; + switch (convg_primitive->GetActivationType()) { + case schema::ActivationType_RELU: + param->act_type_ = ActType_Relu; + break; + case schema::ActivationType_RELU6: + param->act_type_ = ActType_Relu6; + break; + default: + break; + } + + return reinterpret_cast(param); +} + +OpParameter *PopulateConvolutionGradInputParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + + ConvParameter *param = new (std::nothrow) ConvParameter(); + if (param == nullptr) { + MS_LOG(ERROR) << "new Param for conv grad filter failed."; + return nullptr; + } + param->op_parameter_.type_ = primitive->Type(); + + auto convg_primitive = + reinterpret_cast(const_cast(primitive)); + param->kernel_h_ = convg_primitive->GetKernelH(); + param->kernel_w_ = convg_primitive->GetKernelW(); + param->stride_h_ = convg_primitive->GetStrideH(); + param->stride_w_ = convg_primitive->GetStrideW(); + param->dilation_h_ = convg_primitive->GetDilateH(); + param->dilation_w_ = convg_primitive->GetDilateW(); + param->pad_u_ = convg_primitive->GetPadUp(); + param->pad_d_ = convg_primitive->GetPadDown(); + param->pad_l_ = convg_primitive->GetPadLeft(); + param->pad_r_ = convg_primitive->GetPadRight(); + param->group_ = convg_primitive->GetGroup(); + param->act_type_ = ActType_No; + switch (convg_primitive->GetActivationType()) { + case schema::ActivationType_RELU: + param->act_type_ = ActType_Relu; + break; + case schema::ActivationType_RELU6: + param->act_type_ = ActType_Relu6; + break; + default: + break; + } + + return reinterpret_cast(param); +} + +OpParameter *PopulatePowerGradParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + + PowerParameter *power_param = new (std::nothrow) PowerParameter(); + if (power_param == nullptr) { + MS_LOG(ERROR) << "new PowerParameter failed."; + return nullptr; + } + power_param->op_parameter_.type_ = primitive->Type(); + auto power = reinterpret_cast(const_cast(primitive)); + power_param->power_ = power->GetPower(); + power_param->scale_ = power->GetScale(); + power_param->shift_ = power->GetShift(); + return reinterpret_cast(power_param); +} + +void PopulateTrainParameters() { + auto ppr = PopulateParameterRegistry::GetInstance(); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_ApplyMomentum, DefaultPopulateParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_BiasGrad, PopulateArithmetic); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_SoftmaxCrossEntropy, PopulateSoftmaxCrossEntropyParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_TupleGetItem, DefaultPopulateParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_Depend, DefaultPopulateParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, DefaultPopulateParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradFilter, PopulateConvolutionGradFilterParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradInput, PopulateConvolutionGradInputParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_PoolingGrad, PopulatePoolingGradParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter); +} + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/train/train_populate_parameter.h b/mindspore/lite/src/train/train_populate_parameter.h new file mode 100644 index 0000000000..3c187850f0 --- /dev/null +++ b/mindspore/lite/src/train/train_populate_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_POPULATE_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_POPULATE_PARAMETER_H_ + +#include "src/ops/primitive_c.h" + +namespace mindspore::kernel { + + void PopulateTrainParameters(); + + +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_POPULATE_PARAMETER_H_ diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc new file mode 100644 index 0000000000..408599209e --- /dev/null +++ b/mindspore/lite/src/train/train_session.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/train_session.h" +#include +#include "utils/log_adapter.h" +#include "include/context.h" +#include "src/common/utils.h" +#include "mindspore/lite/src/ir/tensor.h" +#include "src/train/loss_kernel.h" +#include "src/train/train_populate_parameter.h" +#include "src/runtime/runtime_api.h" +#include "src/executor.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32_grad/convolution.h" + +namespace mindspore::session { + +TrainSession::TrainSession() { kernel::PopulateTrainParameters(); } + +void TrainSession::ReplaceOps() { + mindspore::lite::KernelRegistrar tmp(mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, + mindspore::schema::PrimitiveType_Conv2D, + mindspore::kernel::CpuConvTrainFp32KernelCreator); +} + +int TrainSession::CompileGraph(lite::Model *model) { + model_ = model; + ReplaceOps(); + return LiteSession::CompileGraph(model); +} + +void* TrainSession::ExportToBuf(void* buf, size_t *len) const { +// auto train_model_impl = (dynamic_cast(model_->model_impl())); +// return train_model_impl->ExportToBuf(buf, len); + return nullptr; +} + + +int TrainSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) { + auto ms_output_tensors = GetOutputs(); + this->outputs_.clear(); + for (auto ms_tensors : ms_output_tensors) + for (auto ms_tensor : ms_tensors.second) + this->outputs_.push_back((dynamic_cast(ms_tensor))->tensor()); + if (train_mode_) + return LiteSession::RunGraph(before, after); + + // object is expected to run only inference part of graph + // prepare a lit of kernels till the loss function -- temporary solution + std::vector infference_kernels; + for (auto kernel : this->kernels_) { + if (dynamic_cast(kernel) != nullptr) + break; + infference_kernels.push_back(kernel); + } + + MS_EXCEPTION_IF_NULL(this->context_); + // TODO(Emir) + // SetMaxWokerNum(context_->thread_num_); + // context_->running_ = true; + lite::Executor executor; + if (before == nullptr && after == nullptr) { + return executor.Run(this->inputs_, this->outputs_, infference_kernels, this->context_->allocator.get()); + } else { + return executor.Run(this->inputs_, this->outputs_, infference_kernels, this->context_->allocator.get(), + before, after); + } +} + +void TrainSession::train() { + for (auto *kernel : kernels_) { + MS_ASSERT(nullptr != kernel); + kernel->train(); + } + train_mode_ = true; + ext_output_map_.clear(); + for (auto kernel : this->kernels_) { + if (dynamic_cast(kernel) != nullptr) { + auto *ms_tensor = new lite::tensor::LiteTensor(kernel->out_tensors().at(0)); + ext_output_map_[kernel->name()].emplace_back(ms_tensor); + } + } +} + +void TrainSession::eval() { + for (auto *kernel : kernels_) { + MS_ASSERT(nullptr != kernel); + kernel->eval(); + } + train_mode_ = false; + kernel::LiteKernel* last_kernel = nullptr; + // We should get in_kernels and then get all last kernels + ext_output_map_ = output_node_map_; + for (auto kernel : this->kernels_) { + if ((dynamic_cast(kernel) != nullptr) && + (last_kernel != nullptr)) { + auto *ms_tensor = new lite::tensor::LiteTensor(last_kernel->out_tensors().at(0)); + ext_output_map_[last_kernel->name()].emplace_back(ms_tensor); + } + last_kernel = kernel; + } +} + +std::unordered_map> TrainSession::GetOutputs() const { + return ext_output_map_; +} +std::vector TrainSession::GetOutputsByName(const std::string &name) const { + auto ret_vect = LiteSession::GetOutputsByNodeName(name); // TODO(emir): GetOutputsByTensorName? + if (ret_vect.size() > 0) + return ret_vect; + auto ret = ext_output_map_.find(name); + if (ret == ext_output_map_.end()) { + MS_LOG(WARNING) << "Node " << name << " is not an output node"; + std::vector empty_ret; + return empty_ret; + } + return ret->second; +} + + + +} // namespace mindspore::session diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index d922ba61f0..f83e839823 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -259,6 +259,10 @@ endif() if (SUPPORT_TRAIN) set(TEST_LITE_SRC ${TEST_LITE_SRC} + # ${LITE_DIR}/src/train/ops/train_ops.cc + ${LITE_DIR}/src/train/train_populate_parameter.cc + ${LITE_DIR}/src/train/train_session.cc + ${LITE_DIR}/src/lite_session.cc # ${SRC_DIR}/common/trans.cc # ${SRC_DIR}/common/lite/trans_extends.cc # ${SRC_DIR}/kernel/kernel_build_info.cc diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc index 11a87e7366..de0fd2045d 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc @@ -25,9 +25,10 @@ #include "mindspore/lite/src/ir/tensor.h" #include "mindspore/lite/src/lite_kernel.h" #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h" +#include "nnacl/fp32_grad/activation_grad.h" namespace mindspore { -class TestActGradFp32 : public mindspore::CommonTest { +class TestActGradFp32 : public mindspore::CommonTest { public: TestActGradFp32() {} }; @@ -41,13 +42,14 @@ TEST_F(TestActGradFp32, ReluGradFp32) { size_t input_size; std::string input_path = "./test_data/activationGrad/relu_y_50.bin"; auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + EXPECT_EQ(input_size, output_data_size * sizeof(float)); std::string yt_path = "./test_data/activationGrad/relu_yt_50.bin"; auto yt_data = reinterpret_cast(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); - + EXPECT_EQ(input_size, output_data_size * sizeof(float)); auto output_data = new float[output_data_size]; // warm up loop for (int i = 0; i < 3; i++) { - ReluGrad(yt_data, input_data, 50, output_data); + ReluGrad(yt_data, input_data, output_data_size, output_data); } int loop_count = 100; @@ -72,9 +74,9 @@ TEST_F(TestActGradFp32, ReluGradFp32) { EXPECT_EQ(res, 0); - delete input_data; + delete[] input_data; delete[] output_data; - delete yt_data; + delete[] yt_data; MS_LOG(INFO) << "ReluGradFp32 passed"; } @@ -118,9 +120,9 @@ TEST_F(TestActGradFp32, Relu6GradFp32) { EXPECT_EQ(res, 0); - delete input_data; + delete[] input_data; delete[] output_data; - delete yt_data; + delete[] yt_data; MS_LOG(INFO) << "Relu6GradFp32 passed"; } @@ -164,9 +166,9 @@ TEST_F(TestActGradFp32, LReluGradFp32) { EXPECT_EQ(res, 0); - delete input_data; + delete[] input_data; delete[] output_data; - delete yt_data; + delete[] yt_data; MS_LOG(INFO) << "LReluGradFp32 passed"; } @@ -211,9 +213,9 @@ TEST_F(TestActGradFp32, SigmoidGradFp32) { EXPECT_EQ(res, 0); // lite::CompareOutput(output_data, output_path); - delete input_data; + delete[] input_data; delete[] output_data; - delete yt_data; + delete[] yt_data; MS_LOG(INFO) << "SigmoidGradFp32 passed"; } @@ -257,9 +259,9 @@ TEST_F(TestActGradFp32, tanhGradFp32) { EXPECT_EQ(res, 0); - delete input_data; + delete[] input_data; delete[] output_data; - delete yt_data; + delete[] yt_data; MS_LOG(INFO) << "TanhGradFp32 passed"; } @@ -267,24 +269,25 @@ TEST_F(TestActGradFp32, hswishGradFp32) { // runtime part printf("Calculating runtime cost...\n"); uint64_t time_avg = 0; - size_t output_data_size = 50; + const size_t output_data_size = 10; size_t input_size; std::string input_path = "./test_data/activationGrad/hswish_x_50.bin"; auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + EXPECT_EQ(input_size, output_data_size * sizeof(float)); std::string yt_path = "./test_data/activationGrad/hswish_yt_50.bin"; auto yt_data = reinterpret_cast(mindspore::lite::ReadFile(yt_path.c_str(), &input_size)); - + EXPECT_EQ(input_size, output_data_size * sizeof(float)); auto output_data = new float[output_data_size]; // warm up loop for (int i = 0; i < 3; i++) { - HSwishGrad(yt_data, input_data, 50, output_data); + HSwishGrad(yt_data, input_data, static_cast(output_data_size), output_data); } int loop_count = 100; auto time_start = mindspore::lite::GetTimeUs(); for (int i = 0; i < loop_count; i++) { - HSwishGrad(yt_data, input_data, 50, output_data); + HSwishGrad(yt_data, input_data, output_data_size, output_data); } auto time_end = mindspore::lite::GetTimeUs(); auto cost = time_end - time_start; @@ -292,7 +295,7 @@ TEST_F(TestActGradFp32, hswishGradFp32) { printf("single thread running time : %f ms\n", time_avg / 1000.0f); printf("==================output data=================\n"); - for (int i = 0; i < 20; i++) { + for (int i = 0; i < std::min(output_data_size, 20UL); i++) { std::cout << output_data[i] << " ,"; } std::cout << std::endl; @@ -302,9 +305,9 @@ TEST_F(TestActGradFp32, hswishGradFp32) { EXPECT_EQ(res, 0); - delete input_data; + delete[] input_data; delete[] output_data; - delete yt_data; + delete[] yt_data; MS_LOG(INFO) << "hswishGradFp32 passed"; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc index 1d317bbd2f..2bb74b022d 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc @@ -106,9 +106,14 @@ TEST_F(TestArithmeticGradFp32, TestAddGradFp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_1_dx2_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // delete all_tensors; + // delete param; + delete kernel_obj; MS_LOG(INFO) << "TestAddGradFp32 passed"; } @@ -137,9 +142,14 @@ TEST_F(TestArithmeticGradFp32, TestAddGrad2Fp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_1_dx2_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // for (int i = 0; i < 5; i++) delete all_tensors[i]; //TODO tensor data is unique pointer + // delete param; + delete kernel_obj; MS_LOG(INFO) << "TestAddGrad2Fp32 passed"; } @@ -169,8 +179,14 @@ TEST_F(TestArithmeticGradFp32, TestAddGrad3Fp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_8_dx1_5_4_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // for (int i = 0; i < 5; i++) delete all_tensors[i]; + // delete param; + delete kernel_obj; MS_LOG(INFO) << "TestAddGrad3Fp32 passed"; } @@ -200,8 +216,14 @@ TEST_F(TestArithmeticGradFp32, TestSubGradFp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_2_dx2_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // for (int i = 0; i < 5; i++) delete all_tensors[i]; + // delete param; + delete kernel_obj; MS_LOG(INFO) << "TestSubGradFp32 passed"; } @@ -231,8 +253,12 @@ TEST_F(TestArithmeticGradFp32, TestSubGrad2Fp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_3_dx2_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + delete kernel_obj; MS_LOG(INFO) << "TestSubGrad2Fp32 passed"; } @@ -271,9 +297,13 @@ TEST_F(TestArithmeticGradFp32, TestMulGradFp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_4_dx2_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + delete kernel_obj; + // delete param; MS_LOG(INFO) << "TestMulGradFp32 passed"; } @@ -302,9 +332,14 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad2Fp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_4_dx2_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // for (int i = 0; i < 5; i++) delete all_tensors[i]; + // delete param; + delete kernel_obj; MS_LOG(INFO) << "TestMulGrad2Fp32 passed"; } @@ -333,9 +368,14 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad3Fp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // for (int i = 0; i < 5; i++) delete all_tensors[i]; + // delete param; + delete kernel_obj; MS_LOG(INFO) << "TestMulGrad3Fp32 passed"; } @@ -364,9 +404,14 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad4Fp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_9_dx2_5_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // for (int i = 0; i < 5; i++) delete all_tensors[i]; + // delete param; + delete kernel_obj; MS_LOG(INFO) << "TestMulGrad4Fp32 passed"; } @@ -395,9 +440,14 @@ TEST_F(TestArithmeticGradFp32, TestDivGradFp32) { std::string dx2_path = "./test_data/operators/arithmetic_fp32_5_dx2_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, dx2_path)); - - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // for (int i = 0; i < 5; i++) delete all_tensors[i]; + delete kernel_obj; + // delete param; MS_LOG(INFO) << "TestDivGradFp32 passed"; } @@ -427,8 +477,14 @@ TEST_F(TestArithmeticGradFp32, TestDivGrad2Fp32) { std::string output_path = "./test_data/operators/arithmetic_fp32_6_dx1_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path)); - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // for (int i = 0; i < 5; i++) delete all_tensors[i]; + // delete param; + delete kernel_obj; MS_LOG(INFO) << "TestDivGrad2Fp32 passed"; } @@ -457,9 +513,14 @@ TEST_F(TestArithmeticGradFp32, TestDivGrad3Fp32) { std::string output_path = "./test_data/operators/arithmetic_fp32_10_dx2_5_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path)); - - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + // for (int i = 0; i < 5; i++) delete all_tensors[i]; + // delete param; + delete kernel_obj; MS_LOG(INFO) << "TestDivGrad3Fp32 passed"; } @@ -488,9 +549,12 @@ TEST_F(TestArithmeticGradFp32, Test3DDivGrad2Fp32) { std::string output_path = "./test_data/operators/arithmetic_fp32_7_dx2_1_1_6.bin"; EXPECT_EQ(0, lite::CompareRelativeOutput(output_ptr, output_path)); - - for (int i = 0; i < 5; i++) delete all_tensors[i]; - delete param; + for (auto tensor : all_tensors) { + delete[] reinterpret_cast(tensor->Data()); + tensor->SetData(nullptr); + delete tensor; + } + delete kernel_obj; MS_LOG(INFO) << "TestDivGrad2Fp32 passed"; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc index 5c52b57ff7..241ebdfcdd 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc @@ -18,8 +18,8 @@ #include "utils/log_adapter.h" #include "common/common_test.h" #include "src/common/file_utils.h" -#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.h" -#include "mindspore/lite/src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32_grad/bias_grad.h" +#include "src/kernel_registry.h" namespace mindspore { @@ -40,9 +40,8 @@ TEST_F(TestBiasGradFp32, BiasGradFp32) { dy_tensor.SetData(input_data); std::vector inputs = {&dy_tensor}; - auto output_data = new float[7]; - std::vector dim_dw({7}); + std::vector dim_dw = {7}; lite::tensor::Tensor dw_tensor(TypeId::kNumberTypeFloat32, dim_dw); dw_tensor.SetData(output_data); std::vector outputs = {&dw_tensor}; @@ -62,9 +61,12 @@ TEST_F(TestBiasGradFp32, BiasGradFp32) { std::string output_path = "./test_data/operators/biasgradfp32_1_db_7.bin"; lite::CompareOutput(output_data, output_path); - // delete input_data; - // delete[] output_data; - delete bias_param; + delete [] input_data; + delete[] output_data; + // delete bias_param; + dy_tensor.SetData(nullptr); + dw_tensor.SetData(nullptr); + delete kernel_obj; MS_LOG(INFO) << "BiasGradFp32 passed"; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc new file mode 100644 index 0000000000..7a29177034 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "src/common/file_utils.h" +#include "src/common/file_utils_ext.h" +#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h" +#include "nnacl/fp32_grad/batch_norm.h" +#include "src/kernel_registry.h" +# + +namespace mindspore { + +class TestBNGradFp32 : public mindspore::CommonTest { + public: + TestBNGradFp32() {} + lite::tensor::Tensor *CreateInTensor(std::string file_name, std::vector dim); +}; + +lite::tensor::Tensor *TestBNGradFp32::CreateInTensor(std::string file_name, std::vector dim) { + size_t input_size = 0; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(file_name.c_str(), &input_size)); + auto tensor = new lite::tensor::Tensor(TypeId::kNumberTypeFloat32, dim); + tensor->SetData(input_data); + EXPECT_EQ(input_size, tensor->Size()); + return tensor; +} + +TEST_F(TestBNGradFp32, BNGradFp32) { + // prepare stage + auto bn_param = new BNGradParameter(); + bn_param->epsilon_ = 0.00001; + bn_param->momentum_ = 0.1; + const int batch = 2; + const int channels = 3; + const int height = 4; + const int width = 5; + + auto dy_tensor = CreateInTensor("./test_data/bngrad/dy_2_4_5_3.bin", {batch, height, width, channels}); + auto x_tensor = CreateInTensor("./test_data/bngrad/input_x_2_4_5_3.bin", {batch, height, width, channels}); + auto scale_tensor = CreateInTensor("./test_data/bngrad/scale_3.bin", {1, 1, 1, channels}); + auto mean_tensor = CreateInTensor("./test_data/bngrad/save_mean_3.bin", {1, 1, 1, channels}); + auto var_tensor = CreateInTensor("././test_data/bngrad/save_var_3.bin", {1, 1, 1, channels}); + // prepare output tensors + lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, {batch, height, width, channels}); + dx_tensor.MallocData(); + lite::tensor::Tensor dscale_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); + dscale_tensor.MallocData(); + lite::tensor::Tensor dbias_tensor(TypeId::kNumberTypeFloat32, {1, 1, 1, channels}); + dbias_tensor.MallocData(); + + std::vector inputs = {dy_tensor, x_tensor, scale_tensor, mean_tensor, var_tensor}; + std::vector outputs = {&dx_tensor, &dscale_tensor, &dbias_tensor}; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_BNGrad}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(bn_param), NULL, desc, nullptr); + + for (int i = 0; i < 3; i++) { + kernel_obj->Run(); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel_obj->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + auto time_avg = cost / loop_count; + std::cout << "single thread running time : " << time_avg << "us\n"; + std::cout << "==========dx==========\n"; + auto dx = reinterpret_cast(outputs[0]->Data()); + for (int i = 0; i < 7; i++) std::cout << dx[i] << " "; + std::cout << "\n=======dscale=======\n"; + auto dscale = reinterpret_cast(outputs[1]->Data()); + for (int i = 0; i < channels; i++) std::cout << dscale[i] << " "; + std::cout << "\n"; + int res = mindspore::lite::CompareRelativeOutput(dscale, "./test_data/bngrad/output_dscale_3.bin"); + EXPECT_EQ(res, 0); + std::cout << "==========dbias==========\n"; + auto dbias = reinterpret_cast(outputs[2]->Data()); + for (int i = 0; i < 3; i++) std::cout << dbias[i] << " "; + std::cout << "\n"; + res = mindspore::lite::CompareRelativeOutput(dscale, "./test_data/bngrad/output_dscale_3.bin"); + EXPECT_EQ(res, 0); + for (auto v : inputs) { + delete[] reinterpret_cast(v->Data()); + v->SetData(nullptr); + // delete v; + } + delete kernel_obj; + MS_LOG(INFO) << "BNGradFp32 passed"; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc index 76fe27358c..81c8abe2e8 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc @@ -21,6 +21,7 @@ #include "common/common_test.h" #include "src/common/file_utils.h" #include "src/common/file_utils_ext.h" +#include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h" #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h" #include "mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h" #include "mindspore/lite/nnacl/conv_parameter.h" @@ -130,11 +131,14 @@ TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) { EXPECT_EQ(res, 0); - // delete input_data; - // delete dy_data; - // delete [] dw_data; + delete [] input_data; + delete [] dy_data; + delete [] dw_data; delete kernel; - delete conv_param; + // delete conv_param; + dw_tensor.SetData(nullptr); + x_tensor.SetData(nullptr); + dy_tensor.SetData(nullptr); MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; } @@ -193,9 +197,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) { std::string output_path = "./test_data/conv/convfp32_dx_1_28_28_3.bin"; auto res = lite::CompareRelativeOutput(dx_data, output_path); EXPECT_EQ(res, 0); - + delete [] dx_data; + delete [] w_data; + delete [] dy_data; + w_tensor.SetData(nullptr); + dy_tensor.SetData(nullptr); + dx_tensor.SetData(nullptr); delete kernel; - delete conv_param; + // delete conv_param; + MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; } @@ -254,11 +264,14 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) { auto res = lite::CompareRelativeOutput(dw_data, output_path); EXPECT_EQ(res, 0); - // delete input_data; - // delete dy_data; - // delete [] dw_data; + delete [] input_data; + delete [] dy_data; + delete [] dw_data; + dw_tensor.SetData(nullptr); + x_tensor.SetData(nullptr); + dy_tensor.SetData(nullptr); delete kernel; - delete conv_param; + // delete conv_param; MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; } @@ -317,9 +330,15 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) { std::string output_path = "./test_data/conv/convfp32_dx_g3_1_28_28_3.bin"; auto res = lite::CompareRelativeOutput(dx_data, output_path); EXPECT_EQ(res, 0); + delete [] dx_data; + delete [] w_data; + delete [] dy_data; + dx_tensor.SetData(nullptr); + w_tensor.SetData(nullptr); + dy_tensor.SetData(nullptr); delete kernel; - delete conv_param; + // delete conv_param; MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; } @@ -378,11 +397,14 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) { std::string output_path = "./test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin"; auto res = lite::CompareRelativeOutput(dw_data, output_path); EXPECT_EQ(res, 0); - // delete input_data; - // delete dy_data; - // delete [] dw_data; + delete [] input_data; + delete [] dy_data; + delete [] dw_data; + dw_tensor.SetData(nullptr); + dy_tensor.SetData(nullptr); + x_tensor.SetData(nullptr); delete kernel; - delete conv_param; + // delete conv_param; MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; } @@ -441,80 +463,93 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationInputGrad) { std::string output_path = "./test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin"; auto res = lite::CompareRelativeOutput(dx_data, output_path); EXPECT_EQ(res, 0); - + delete [] dx_data; + delete [] w_data; + delete [] dy_data; + dx_tensor.SetData(nullptr); + dy_tensor.SetData(nullptr); + w_tensor.SetData(nullptr); delete kernel; - delete conv_param; + // delete conv_param; MS_LOG(INFO) << "TestConvolutionGradFp32 Filter Grad passed"; } -// TEST_F(TestConvolutionGradFp32, ConvGroupDilation) { -// // prepare stage -// auto conv_param = new ConvParameter(); -// InitConvParamGroup3Dilation2FP32(conv_param); - -// size_t x_size; -// std::string x_path = "./test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin"; -// auto x_data = reinterpret_cast(mindspore::lite::ReadFile(x_path.c_str(), &x_size)); -// std::vector dim_x({1, 28, 28, 3}); -// tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); -// x_tensor.SetData(x_data); - -// size_t w_size; -// std::string w_path = "./test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin"; -// auto w_data = reinterpret_cast(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); -// std::vector dim_w({18, 3, 3, 1}); -// tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_w); -// w_tensor.SetData(w_data); - -// size_t output_data_size = -// conv_param->output_batch_ * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; -// auto y_data = new float[output_data_size]; -// std::vector dim_y({1, 26, 26, 18}); -// tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); -// y_tensor.SetData(y_data); - -// std::vector inputs = {&x_tensor, &w_tensor}; -// std::vector outputs = {&y_tensor}; -// // runtime part - -// printf("Calculating runtime cost...\n"); -// uint64_t time_avg = 0; - -// lite::Context context; -// ; -// context.deviceCtx.type = lite::DT_CPU; -// context.threadNum = 1; - -// kernel::KernelKey desc = {kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Conv2D}; -// auto creator = lite::KernelRegistry::GetInstance()->GetKernelCreator(desc); -// auto kernel = creator(inputs, outputs, (OpParameter *)conv_param, &context, desc); - -// kernel->train(); -// EXPECT_EQ(kernel->is_train(), 1); - -// // warm up loop -// for (int i = 0; i < 3; i++) { -// kernel->Run(); -// } - -// int loop_count = 100; -// auto time_start = mindspore::lite::GetTimeUs(); -// for (int i = 0; i < loop_count; i++) { -// kernel->Run(); -// } -// auto time_end = mindspore::lite::GetTimeUs(); -// auto cost = time_end - time_start; -// time_avg = cost / loop_count; -// printf("single thread running time : %f ms\n", time_avg / 1000.0f); - -// std::string output_path = "./test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin"; -// auto res = lite::CompareRelativeOutput(y_data, output_path); -// EXPECT_EQ(res, 0); - -// delete kernel; -// delete conv_param; - -// MS_LOG(INFO) << "TestConvolutionFp32 Filter Grad passed"; -// } +TEST_F(TestConvolutionGradFp32, ConvGroupDilation) { + // prepare stage + auto conv_param = new ConvParameter(); + InitConvParamGroup3Dilation2FP32(conv_param); + + size_t x_size; + std::string x_path = "./test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin"; + auto x_data = reinterpret_cast(mindspore::lite::ReadFile(x_path.c_str(), &x_size)); + std::vector dim_x({1, 28, 28, 3}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(x_data); + + size_t w_size; + std::string w_path = "./test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin"; + auto w_data = reinterpret_cast(mindspore::lite::ReadFile(w_path.c_str(), &w_size)); + std::vector dim_w({18, 3, 3, 1}); + lite::tensor::Tensor w_tensor(TypeId::kNumberTypeFloat32, dim_w); + w_tensor.SetData(w_data); + + size_t output_data_size = + conv_param->output_batch_ * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + auto y_data = new float[output_data_size]; + std::vector dim_y({1, 26, 26, 18}); + lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); + y_tensor.SetData(y_data); + + std::vector inputs = {&x_tensor, &w_tensor}; + std::vector outputs = {&y_tensor}; + // runtime part + + printf("Calculating runtime cost...\n"); + uint64_t time_avg = 0; + + lite::Context context; + context.device_ctx_.type = lite::DT_CPU; + context.thread_num_ = 1; + + + auto *kernel = new mindspore::kernel::ConvolutionTrainCPUKernel(reinterpret_cast(conv_param), + inputs, outputs, &context, 0); + kernel->Init(); + // kernel::KernelKey desc = {kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Conv2D}; + // auto creator = lite::KernelRegistry::GetInstance()->GetKernelCreator(desc); + // auto kernel = creator(inputs, outputs, (OpParameter *)conv_param, &context, desc); + + kernel->train(); + EXPECT_EQ(kernel->is_train(), 1); + + // warm up loop + for (int i = 0; i < 3; i++) { + kernel->Run(); + } + + int loop_count = 100; + auto time_start = mindspore::lite::GetTimeUs(); + for (int i = 0; i < loop_count; i++) { + kernel->Run(); + } + auto time_end = mindspore::lite::GetTimeUs(); + auto cost = time_end - time_start; + time_avg = cost / loop_count; + printf("single thread running time : %f ms\n", time_avg / 1000.0f); + + std::string output_path = "./test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin"; + auto res = lite::CompareRelativeOutput(y_data, output_path); + EXPECT_EQ(res, 0); + + delete [] y_data; + delete [] x_data; + delete [] w_data; + x_tensor.SetData(nullptr); + y_tensor.SetData(nullptr); + w_tensor.SetData(nullptr); + delete kernel; + + MS_LOG(INFO) << "TestConvolutionFp32 Filter Grad passed"; +} } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc new file mode 100644 index 0000000000..d937c5ba96 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc @@ -0,0 +1,564 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mindspore/lite/schema/inner/model_generated.h" +#include "mindspore/lite/include/model.h" +#include "common/common_test.h" +#include "include/train_session.h" +// #include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/common/file_utils.h" +#include "src/common/file_utils_ext.h" + +namespace mindspore { +class NetworkTest : public mindspore::CommonTest { + public: + NetworkTest() {} +}; + + +// INPUT(0) +// V +// +-------------+ +// | ReLU | +// +-------------+ +// +---output(1) V +// | V V weights(2) <----+ +// | +-------------+ | +// | | MatMul | | +// | +-------------+ | +// | output(3) V | +// | V V weights(4)<-+ | +// | +-------------+ | | +// | | Bias | | | +// | +-------------+ | | +// | output(5) V | | +// | V V LABELS(6) | | +// | +-------------+ | | +// | | CrossEntropy| | | +// | +-------------+ | | +// | +-dy(7) V V------------------------->Loss (14) +// | | V | | +// | | +-------------+ | | +// | | | BiasGrad | | | +// | | +-------------+ | | +// | | V db(8) | | +// | | +--------Update---+ | +// | +-------+ | +// +------V V | +// +-------------+ | +// | MatMul | | +// +-------------+ | +// V dw(9) | +// +-----------Update-----+ + + +TEST_F(NetworkTest, tuning_layer) { + const int BATCH_SIZE = 32; + const int NUM_CLASSES = 10; + const int FEATURE_SIZE = 1000; + auto meta_graph = std::make_shared(); + meta_graph->name = "graph"; + // define nodes + { + auto node = std::make_unique(); + node->inputIndex = {0}; + node->outputIndex = {1}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Activation; + auto primitive = new schema::ActivationT; + primitive->type = schema::ActivationType_RELU; + node->primitive->value.value = primitive; + node->name = "ReLU"; + meta_graph->nodes.emplace_back(std::move(node)); + } + { + auto node = std::make_unique(); + node->inputIndex = {1, 2}; + node->outputIndex = {3}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_MatMul; + auto primitive = new schema::MatMulT; + primitive->transposeA = false; + primitive->transposeB = true; + node->primitive->value.value = primitive; + node->name = "MatMul1"; + meta_graph->nodes.emplace_back(std::move(node)); + } + { + auto node = std::make_unique(); + node->inputIndex = {3, 4}; + node->outputIndex = {5}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_BiasAdd; + auto primitive = new schema::BiasAddT; + primitive->axis.push_back(0); + node->primitive->value.value = primitive; + node->name = "BiasAdd"; + meta_graph->nodes.emplace_back(std::move(node)); + } + { + auto node = std::make_unique(); + node->inputIndex = {5, 6}; + node->outputIndex = {14, 7}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_SoftmaxCrossEntropy; + auto primitive = new schema::SoftmaxCrossEntropyT; + primitive->axis.push_back(0); + node->primitive->value.value = primitive; + node->name = "SoftmaxCrossEntropy"; + meta_graph->nodes.emplace_back(std::move(node)); + } + { + auto node = std::make_unique(); + node->inputIndex = {7}; + node->outputIndex = {8}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_BiasGrad; + auto primitive = new schema::BiasGradT; + primitive->axis.push_back(0); + node->primitive->value.value = primitive; + node->name = "BiasGrad"; + meta_graph->nodes.emplace_back(std::move(node)); + } + { + auto node = std::make_unique(); + node->inputIndex = {7, 1}; + node->outputIndex = {9}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_MatMul; + auto primitive = new schema::MatMulT; + primitive->transposeA = true; + primitive->transposeB = false; + node->primitive->value.value = primitive; + node->name = "MatMul2"; + meta_graph->nodes.emplace_back(std::move(node)); + } + { + auto node = std::make_unique(); + node->inputIndex = {2, 10, 11, 9, 12}; + node->outputIndex = {}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_ApplyMomentum; + auto primitive = new schema::ApplyMomentumT; + node->primitive->value.value = primitive; + node->name = "Momentum"; + meta_graph->nodes.emplace_back(std::move(node)); + } + { + auto node = std::make_unique(); + node->inputIndex = {4, 13, 11, 8, 12}; + node->outputIndex = {}; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_ApplyMomentum; + auto primitive = new schema::ApplyMomentumT; + node->primitive->value.value = primitive; + node->name = "Momentum"; + meta_graph->nodes.emplace_back(std::move(node)); + } + meta_graph->inputIndex = {6, 0}; // XXX TODO why is it reverse? + meta_graph->outputIndex = {5, 14}; + const int NUM_OF_OUTPUTS = 2; + + auto input0 = std::make_unique(); + input0->nodeType = schema::NodeType::NodeType_ValueNode; + input0->format = schema::Format_NHWC; + input0->dataType = TypeId::kNumberTypeFloat32; + input0->dims = {BATCH_SIZE, FEATURE_SIZE}; + input0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input0)); + // tensor 1 - relu + auto relu_out = std::make_unique(); + relu_out->nodeType = schema::NodeType::NodeType_Parameter; + relu_out->format = schema::Format_NHWC; + relu_out->dataType = TypeId::kNumberTypeFloat32; + relu_out->dims = {BATCH_SIZE, FEATURE_SIZE}; + relu_out->offset = -1; + meta_graph->allTensors.emplace_back(std::move(relu_out)); + // tensor 2 - matmul weights + auto weight = std::make_unique(); + weight->nodeType = schema::NodeType::NodeType_ValueNode; + weight->format = schema::Format_KHWC; + weight->dataType = TypeId::kNumberTypeFloat32; + weight->dims = {NUM_CLASSES, FEATURE_SIZE}; + size_t weight_size; + char *buf; + std::string weight_path = "./test_data/train/train_weight_10_1000.bin"; + ReadFile(weight_path.c_str(), &weight_size, &buf); + ASSERT_NE(nullptr, buf); + weight->data.resize(weight_size); + std::copy(buf, buf + weight_size, weight->data.data()); + meta_graph->allTensors.emplace_back(std::move(weight)); + // tensor 3 - matmul + auto input3 = std::make_unique(); + input3->nodeType = schema::NodeType::NodeType_Parameter; + input3->format = schema::Format_NHWC; + input3->dataType = TypeId::kNumberTypeFloat32; + input3->dims = {BATCH_SIZE, NUM_CLASSES}; + input3->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input3)); + // tensor 4 - fc bias + auto bias = std::make_unique(); + bias->nodeType = schema::NodeType::NodeType_ValueNode; + bias->format = schema::Format_NHWC; + bias->dataType = TypeId::kNumberTypeFloat32; + bias->dims = {NUM_CLASSES}; + bias->offset = -1; + std::string bias_path = "./test_data/train/train_bias_10.bin"; + size_t bias_size; + ReadFile(bias_path.c_str(), &bias_size, &buf); + ASSERT_NE(nullptr, buf); + bias->data.resize(bias_size); + std::copy(buf, buf + bias_size, bias->data.data()); + meta_graph->allTensors.emplace_back(std::move(bias)); + + // tensor 5 - bias_add + auto input5 = std::make_unique(); + input5->nodeType = schema::NodeType::NodeType_Parameter; + input5->format = schema::Format_NHWC; + input5->dataType = TypeId::kNumberTypeFloat32; + input5->dims = {BATCH_SIZE, NUM_CLASSES}; + input5->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input5)); + // tensor 6 - Label + { + auto label = std::make_unique(); + label->nodeType = schema::NodeType::NodeType_ValueNode; + label->format = schema::Format_NHWC; + label->dataType = TypeId::kNumberTypeInt32; + label->dims = {BATCH_SIZE}; + label->offset = -1; + label->data.resize(BATCH_SIZE * NUM_CLASSES * sizeof(float)); + int *data = reinterpret_cast(label->data.data()); + for (int i = 0; i < BATCH_SIZE; i++) + for (int j = 0; j < NUM_CLASSES; j++) *(data + i * NUM_CLASSES + j) = j; + meta_graph->allTensors.emplace_back(std::move(label)); + } + // tensor 7 - Softmaxentropy + auto input7 = std::make_unique(); + input7->nodeType = schema::NodeType::NodeType_Parameter; + input7->format = schema::Format_NHWC; + input7->dataType = TypeId::kNumberTypeFloat32; + input7->dims = {BATCH_SIZE, NUM_CLASSES}; + input7->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input7)); + // tensor 8 - biasGrad + auto input8 = std::make_unique(); + input8->nodeType = schema::NodeType::NodeType_Parameter; + input8->format = schema::Format_NHWC; + input8->dataType = TypeId::kNumberTypeFloat32; + input8->dims = {NUM_CLASSES}; + input8->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input8)); + // tensor 9 - matmul2 + auto input9 = std::make_unique(); + input9->nodeType = schema::NodeType::NodeType_Parameter; + input9->format = schema::Format_NHWC; + input9->dataType = TypeId::kNumberTypeFloat32; + input9->dims = {NUM_CLASSES, FEATURE_SIZE}; + input9->offset = -1; + meta_graph->allTensors.emplace_back(std::move(input9)); + // tensor 10 weights accumulate + auto input10 = std::make_unique(); + input10->nodeType = schema::NodeType::NodeType_ValueNode; + input10->format = schema::Format_NHWC; + input10->dataType = TypeId::kNumberTypeFloat32; + input10->dims = {NUM_CLASSES, FEATURE_SIZE}; + input10->offset = -1; + size_t input10_size = NUM_CLASSES * FEATURE_SIZE * sizeof(float); + input10->data.resize(input10_size); + std::fill(input10->data.data(), input10->data.data() + input10_size, 0.f); + meta_graph->allTensors.emplace_back(std::move(input10)); + // tensor 11 - lr + { + auto lr = std::make_unique(); + lr->nodeType = schema::NodeType::NodeType_ValueNode; + lr->format = schema::Format_NHWC; + lr->dataType = TypeId::kNumberTypeFloat32; + lr->dims = {1}; + lr->offset = -1; + lr->data.resize(sizeof(float)); + float *data = reinterpret_cast(lr->data.data()); + *data = 0.01f; + meta_graph->allTensors.emplace_back(std::move(lr)); + } + // tensor 12 - momentum + { + auto input12 = std::make_unique(); + input12->nodeType = schema::NodeType::NodeType_ValueNode; + input12->format = schema::Format_NHWC; + input12->dataType = TypeId::kNumberTypeFloat32; + input12->dims = {1}; + input12->offset = -1; + input12->data.resize(sizeof(float)); + float *data = reinterpret_cast(input12->data.data()); + *data = 0.f; + meta_graph->allTensors.emplace_back(std::move(input12)); + } + // tensor 13 - bias accumulate + auto input13 = std::make_unique(); + input13->nodeType = schema::NodeType::NodeType_ValueNode; + input13->format = schema::Format_NHWC; + input13->dataType = TypeId::kNumberTypeFloat32; + input13->dims = {NUM_CLASSES}; + input13->offset = -1; + size_t input13_size = NUM_CLASSES * sizeof(float); + input13->data.resize(input13_size); + std::fill(input13->data.data(), input13->data.data() + input13_size, 0.f); + meta_graph->allTensors.emplace_back(std::move(input13)); + + // tensor 14 - loss + { + auto loss14 = std::make_unique(); + loss14->nodeType = schema::NodeType::NodeType_ValueNode; + loss14->format = schema::Format_NHWC; + loss14->dataType = TypeId::kNumberTypeFloat32; + loss14->dims = {1}; + loss14->offset = -1; + loss14->data.resize(sizeof(float)); + float *data = reinterpret_cast(loss14->data.data()); + *data = 0.0f; + meta_graph->allTensors.emplace_back(std::move(loss14)); + } + + //================================================================ + buf = nullptr; + + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, meta_graph.get()); + builder.Finish(offset); + size_t size = builder.GetSize(); + const char *content = reinterpret_cast(builder.GetBufferPointer()); + std::cout << "build fb size= " << size << "\n"; + +#if 0 // EXPORT_FILE + std::string path = std::string("hcdemo_train.fb"); + std::ofstream ofs(path); + ASSERT_EQ(true, ofs.good()); + ASSERT_EQ(true, ofs.is_open()); + + ofs.seekp(0, std::ios::beg); + ofs.write(content, size); + ofs.close(); +#endif + + auto model = lite::Model::Import(content, size); + ASSERT_NE(nullptr, model); + meta_graph.reset(); + content = nullptr; + auto context = new lite::Context; + context->device_ctx_.type = lite::DT_CPU; + context->cpu_bind_mode_ = lite::NO_BIND; + context->thread_num_ = 1; + auto session = new session::TrainSession(); + ASSERT_NE(nullptr, session); + session->Init(context); + auto ret = session->CompileGraph(model); + ASSERT_EQ(lite::RET_OK, ret); + session->train(); + + auto inputs = session->GetInputs(); + ASSERT_EQ(inputs.size(), 2); + auto inTensor = inputs.at(0); + ASSERT_NE(nullptr, inTensor); + auto data = inTensor->MutableData(); + //=================================================== + size_t input_size; + std::string input_path = "./test_data/train/train_input_32_1000.bin"; + ReadFile(input_path.c_str(), &input_size, &buf); + ASSERT_NE(nullptr, buf); + auto input_data = reinterpret_cast(buf); + ASSERT_NE(nullptr, input_data); + //=================================================== + ASSERT_EQ(input_size, inTensor->Size()); + memcpy(data, input_data, input_size); + + auto labelTensor = inputs.at(1); + ASSERT_NE(nullptr, labelTensor); + ASSERT_EQ(BATCH_SIZE, labelTensor->ElementsNum()); + auto labels = reinterpret_cast(labelTensor->MutableData()); + for (int i = 0; i < BATCH_SIZE; i++) labels[i] = (i * 97) % NUM_CLASSES; + + ret = session->RunGraph(); + ASSERT_EQ(lite::RET_OK, ret); + auto outputs = session->GetOutputsByName("BiasAdd"); + ASSERT_EQ(outputs.size(), 1); + auto outTensor = (outputs.at(0)); + ASSERT_NE(nullptr, outTensor); + ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); + auto *outData = reinterpret_cast(outTensor->MutableData()); + ASSERT_NE(nullptr, outData); + std::cout << "========================dW=====================" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << outData[i] << ", "; + } + std::cout << std::endl; + ret = session->RunGraph(); + outputs = session->GetOutputsByName("BiasAdd"); + ASSERT_EQ(outputs.size(), 1); + outTensor = (outputs.at(0)); + ASSERT_NE(nullptr, outTensor); + // ASSERT_EQ(28 * 28 * 32, outTensor->ElementsNum()); + ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); + outData = reinterpret_cast(outTensor->MutableData()); + ASSERT_NE(nullptr, outData); + std::cout << "========================dW=====================" << std::endl; + for (int i = 0; i < 20; i++) { + std::cout << outData[i] << ", "; + } +//=================================================== +#if 0 + size_t output_size; + std::string output_path = "./convfp32_out_1_28_28_32.bin"; + buf = mindspore::lite::ReadFile(output_path.c_str(), &output_size); + ASSERT_NE(nullptr, buf); + auto output_data = reinterpret_cast(buf); + ASSERT_NE(nullptr, output_data); + //=================================================== + ASSERT_EQ(output_size, runOutput->Size()); + for (size_t i = 0; i < runOutput->ElementsNum(); i++) { + ASSERT_EQ(output_data[i], outData[i]); + } +#endif + MS_LOG(INFO) << "Passed"; +} + +int32_t fileIterator(mindspore::session::TrainSession *session, const std::string &path, + std::function cb) { + int32_t res = 0; + if (auto dir = opendir(path.c_str())) { + while (auto f = readdir(dir)) { + if (!f->d_name || f->d_name[0] == '.') continue; + if (f->d_type == DT_DIR) fileIterator(session, path + f->d_name + "/", cb); + + if (f->d_type == DT_REG) + res |= cb(session, path + f->d_name); + } + closedir(dir); + } + return res; +} +#if 0 +void replaceExt(const std::string &src, std::string *dst) { + dst = &std::move(src.substr(0, src.find_last_of('.')) + ".emb"); +} +#endif +int32_t runEffNet(mindspore::session::TrainSession *session, const std::string &in, const std::string &out) { + // setup input + auto inputs = session->GetInputs(); + // ASSERT_EQ(inputs.size(), 1); + auto inTensor = inputs.at(0); + // ASSERT_NE(nullptr, inTensor); + float *data = reinterpret_cast(inTensor->MutableData()); + + size_t input_size; + float *in_buf = reinterpret_cast(lite::ReadFile(in.c_str(), &input_size)); + // ASSERT_NE(nullptr, data); + auto input_data = reinterpret_cast(in_buf); + // ASSERT_EQ(input_size, inTensor->Size()); + std::copy(input_data, input_data + inTensor->ElementsNum(), data); + + // execute network + session->RunGraph(); + + // compare outputs + auto outputs = session->GetOutputs(); + auto output = ((outputs.begin())->second); + float *output_data = reinterpret_cast(output.at(0)->MutableData()); + + return mindspore::lite::CompareRelativeOutput(output_data, out.c_str()); +} + +TEST_F(NetworkTest, efficient_net) { + const int NUM_OF_INPUTS = 1; + char *buf = nullptr; + size_t net_size = 0; + std::string net = "./test_data/nets/efficientnet_b0_f.ms"; + ReadFile(net.c_str(), &net_size, &buf); + auto model = lite::Model::Import(buf, net_size); + auto context = new lite::Context; + context->device_ctx_.type = lite::DT_CPU; + context->cpu_bind_mode_ = lite::NO_BIND; + context->thread_num_ = 1; + + + auto session = new mindspore::session::TrainSession(); + ASSERT_NE(session, nullptr); + auto ret = session->Init(context); + ASSERT_EQ(lite::RET_OK, ret); + ret = session->CompileGraph(model); + ASSERT_EQ(lite::RET_OK, ret); + session->eval(); + +#if 0 + std::string path = "/opt/share/MiniBinEmbDataset/"; + auto res = fileIterator(session, path, [](mindspore::session::TrainSession *session, const std::string &in) { + int32_t res = 0; + if (in.find(".bin") != std::string::npos) { + std::string out; + replaceExt(in, out); + res = runEffNet(session, in, out); + std::cout << "input file: " << in << (res ? " Fail" : " Pass") << std::endl; + } + return res; + }); +#else + std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin"; + std::string out = "./test_data/nets/effNet_output_y_1_1000.bin"; + auto res = runEffNet(session, in, out); +#endif + // auto inputs = session->GetInputs(); + // ASSERT_EQ(inputs.size(), NUM_OF_INPUTS); + // auto inTensor = inputs.at(0); + // ASSERT_NE(nullptr, inTensor); + // float *data = reinterpret_cast(inTensor->MutableData()); + + // // fill input + // std::string input_path = "./test_data/nets/effNet_input_x_1_3_224_224.bin"; + // // std::string input_path = "/opt/share/MiniBinEmbDataset/2_pet/n02099601_3111.bin"; + // size_t input_size; + // char *in_buf = nullptr; + // ReadFile(input_path.c_str(), &input_size, &in_buf); + // ASSERT_NE(nullptr, data); + // auto input_data = reinterpret_cast(in_buf); + // ASSERT_EQ(input_size, inTensor->Size()); + // std::copy(input_data, input_data+inTensor->ElementsNum(), data); + + // // execute network + // ret = session->RunGraph(); + + // // compare outputs + // std::string output_path = "./test_data/nets/effNet_output_y_1_1000.bin"; + // // std::string output_path = "/opt/share/MiniBinEmbDataset/2_pet/n02099601_3111.emb"; + // auto outputs = session->GetOutputs(); + // auto output = ((outputs.begin())->second); + // float* output_data = reinterpret_cast(output.at(0)->MutableData()); + // int res = lite::CompareRelativeOutput(output_data, output_path); + ASSERT_EQ(res, 0); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc index 22aabd7757..546f4580c9 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc @@ -22,6 +22,7 @@ #include "mindspore/lite/src/kernel_registry.h" #include "src/common/utils.h" #include "src/common/file_utils.h" +#include "src/common/file_utils_ext.h" #include "src/runtime/kernel/arm/fp32_grad/pooling_grad.h" #include "nnacl/fp32_grad/pooling_grad.h" @@ -60,6 +61,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) { auto pooling_param = new PoolingParameter(); InitPoolingParamFP32(pooling_param); pooling_param->output_channel_ = 3; + pooling_param->pool_mode_ = PoolMode_AvgPool; // runtime part printf("Calculating runtime cost...\n"); @@ -95,7 +97,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) { std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; lite::CompareOutput(output_data, output_path); - delete input_data; + delete[] input_data; delete[] output_data; delete pooling_param; MS_LOG(INFO) << "TestAvgPoolingGradFp32 passed"; @@ -122,10 +124,10 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { dy_tensor.SetData(input_data); std::string input1_path = "./test_data/pooling/avgpoolgradfp32_1_x_1_28_28_3.bin"; - input_data = reinterpret_cast(mindspore::lite::ReadFile(input1_path.c_str(), &input_size)); + auto input1_data = reinterpret_cast(mindspore::lite::ReadFile(input1_path.c_str(), &input_size)); std::vector dim_x({1, 28, 28, 3}); lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); - x_tensor.SetData(input_data); + x_tensor.SetData(input1_data); std::vector inputs = {&dy_tensor, &x_tensor}; @@ -150,12 +152,205 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; lite::CompareOutput(output_data, output_path); - // delete input_data; - // delete[] output_data; - delete pooling_param; + delete[] input_data; + delete[] input1_data; + delete[] output_data; + dx_tensor.SetData(nullptr); + x_tensor.SetData(nullptr); + dy_tensor.SetData(nullptr); + // delete pooling_param; + delete kernel_obj; MS_LOG(INFO) << "TestAvgPoolingGradFp32 passed"; } +TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) { + // prepare stage + auto pooling_param = new PoolingParameter(); + InitPoolingParamFP32(pooling_param); + + pooling_param->output_channel_ = 3; + pooling_param->input_batch_ = 3; + pooling_param->output_batch_ = 3; + + // runtime part + printf("Calculating runtime cost...\n"); + // uint64_t time_avg = 0; + size_t output_data_size = + pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->input_h_ * pooling_param->input_w_; + + size_t input_size; + std::string input_path = "./test_data/pooling/avgpoolgradfp32_1_dy_3_28_28_3.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + std::vector dim_dy({1, 28, 28, 3}); + lite::tensor::Tensor dy_tensor(TypeId::kNumberTypeFloat32, dim_dy); + dy_tensor.SetData(input_data); + + std::string input1_path = "./test_data/pooling/avgpoolgradfp32_1_x_3_28_28_3.bin"; + auto input1_data = reinterpret_cast(mindspore::lite::ReadFile(input1_path.c_str(), &input_size)); + std::vector dim_x({1, 28, 28, 3}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(input1_data); + + std::vector inputs = {&dy_tensor, &x_tensor}; + + auto output_data = new float[output_data_size]; + std::vector dim_dx({1, 28, 28, 3}); + lite::tensor::Tensor dx_tensor(TypeId::kNumberTypeFloat32, dim_dx); + dx_tensor.SetData(output_data); + std::vector outputs = {&dx_tensor}; + + kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(pooling_param), NULL, desc, nullptr); + + kernel_obj->Run(); + + printf("==================output data=================\n"); + for (int i = 0; i < 20; i++) { + std::cout << output_data[i] << " ,"; + } + std::cout << std::endl; + std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_3_28_28_3.bin"; + lite::CompareOutput(output_data, output_path); + + delete[] input_data; + delete[] input1_data; + delete[] output_data; + dx_tensor.SetData(nullptr); + x_tensor.SetData(nullptr); + dy_tensor.SetData(nullptr); + // delete pooling_param; + delete kernel_obj; + MS_LOG(INFO) << "TestAvgPoolingGradBatchFp32 passed"; +} + +TEST_F(TestPoolingGradFp32, AvgPoolGradStride2Fp32) { + // prepare stage + // input size will be equal to the original size of x, output size will be the output size as in forward + auto pool = new PoolingParameter(); + InitPoolingParamFP32(pool); + pool->output_channel_ = 3; + pool->pool_mode_ = PoolMode_AvgPool; + pool->input_batch_ = 3; + pool->output_batch_ = 3; + pool->output_h_ = 14; + pool->output_w_ = 14; + pool->stride_h_ = 2; + pool->stride_w_ = 2; + + size_t input_size; + size_t y_data_size = pool->output_batch_ * pool->output_channel_ * pool->input_h_ * pool->input_w_; + + auto x_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/avgpoolgradfp32_s2_x_3_28_28_3.bin", &input_size)); + std::vector dim_x({pool->output_batch_, pool->input_h_, pool->input_w_, pool->input_channel_}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(x_data); + + auto yt_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/avgpoolgradfp32_s2_dy_3_28_28_3.bin", &input_size)); + std::vector dim_y({pool->output_batch_, pool->output_h_, pool->output_w_, pool->output_channel_}); + lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, dim_y); + yt_tensor.SetData(yt_data); + + auto out_data = new float[y_data_size]; + lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); + out_tensor.SetData(out_data); + + std::vector inputs = {&yt_tensor, &x_tensor}; + std::vector outputs = {&out_tensor}; + // ---------------------------------------- + kernel::KernelKey pool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; + auto pool_creator = lite::KernelRegistry::GetInstance()->GetCreator(pool_desc); + auto kernel = pool_creator(inputs, outputs, reinterpret_cast(pool), NULL, pool_desc, nullptr); + + kernel->Init(); + + auto time_start = mindspore::lite::GetTimeUs(); + kernel->Run(); + auto time_end = mindspore::lite::GetTimeUs(); + printf("single thread running time : %ld ms\n", time_end - time_start); + + std::string output_path = "./test_data/pooling/avgpoolgradfp32_s2_dx_3_28_28_3.bin"; + auto res = lite::CompareRelativeOutput(out_data, output_path); + + EXPECT_EQ(res, 0); + + delete[] x_data; + delete[] yt_data; + // delete[] out_data; + // delete conv_param; + x_tensor.SetData(nullptr); + yt_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); + delete kernel; + MS_LOG(INFO) << "AvgPoolGradStride2Fp32 Filter Grad passed"; +} + +TEST_F(TestPoolingGradFp32, AvgPoolGradStride3Fp32) { + // prepare stage + // input size will be equal to the original size of x, output size will be the output size as in forward + auto pool = new PoolingParameter(); + InitPoolingParamFP32(pool); + pool->output_channel_ = 3; + pool->pool_mode_ = PoolMode_AvgPool; + pool->input_batch_ = 3; + pool->output_batch_ = 3; + pool->output_h_ = 10; + pool->output_w_ = 10; + pool->stride_h_ = 3; + pool->stride_w_ = 3; + + size_t input_size; + size_t y_data_size = pool->output_batch_ * pool->output_channel_ * pool->input_h_ * pool->input_w_; + + auto x_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/avgpoolgradfp32_s3_x_3_28_28_3.bin", &input_size)); + std::vector dim_x({pool->output_batch_, pool->input_h_, pool->input_w_, pool->input_channel_}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(x_data); + + auto yt_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/avgpoolgradfp32_s3_dy_3_28_28_3.bin", &input_size)); + std::vector dim_y({pool->output_batch_, pool->output_h_, pool->output_w_, pool->output_channel_}); + lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, dim_y); + yt_tensor.SetData(yt_data); + + auto out_data = new float[y_data_size]; + lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); + out_tensor.SetData(out_data); + + std::vector inputs = {&yt_tensor, &x_tensor}; + std::vector outputs = {&out_tensor}; + // ---------------------------------------- + kernel::KernelKey pool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; + auto pool_creator = lite::KernelRegistry::GetInstance()->GetCreator(pool_desc); + auto kernel = pool_creator(inputs, outputs, reinterpret_cast(pool), NULL, pool_desc, nullptr); + + kernel->Init(); + + auto time_start = mindspore::lite::GetTimeUs(); + kernel->Run(); + auto time_end = mindspore::lite::GetTimeUs(); + printf("single thread running time : %ld ms\n", time_end - time_start); + + std::string output_path = "./test_data/pooling/avgpoolgradfp32_s3_dx_3_28_28_3.bin"; + auto res = lite::CompareRelativeOutput(out_data, output_path); + + EXPECT_EQ(res, 0); + + delete[] x_data; + delete[] yt_data; + // delete[] out_data; + // delete conv_param; + x_tensor.SetData(nullptr); + yt_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); + delete kernel; + MS_LOG(INFO) << "AvgPoolGradStride3Fp32 Filter Grad passed"; +} + TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { // prepare stage auto pooling_param = new PoolingParameter(); @@ -169,26 +364,25 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { pooling_param->output_batch_ * pooling_param->output_channel_ * pooling_param->output_h_ * pooling_param->output_w_; size_t input_size; - std::string i_path = "./test_data/pooling/maxpoolgradfp32_1_i_1_28_28_3.bin"; - auto ill_data = reinterpret_cast(mindspore::lite::ReadFile(i_path.c_str(), &input_size)); - auto i_data = new int[output_data_size]; - for (uint32_t i = 0; i < output_data_size; i++) { - i_data[i] = static_cast(ill_data[i]); - } + std::string i_path = "./test_data/pooling/maxpoolgradfp32_1_x_1_28_28_3.bin"; + auto in_data = reinterpret_cast(mindspore::lite::ReadFile(i_path.c_str(), &input_size)); std::string dy_path = "./test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin"; auto dy_data = reinterpret_cast(mindspore::lite::ReadFile(dy_path.c_str(), &input_size)); + std::string dx_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin"; + auto dx_data = reinterpret_cast(mindspore::lite::ReadFile(dx_path.c_str(), &input_size)); + auto output_data = new float[output_data_size]; // warm up loop for (int i = 0; i < 3; i++) { - MaxPoolingGrad(dy_data, i_data, output_data, pooling_param); + MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param); } int loop_count = 100; auto time_start = mindspore::lite::GetTimeUs(); for (int i = 0; i < loop_count; i++) { - MaxPoolingGrad(dy_data, i_data, output_data, pooling_param); + MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param); } auto time_end = mindspore::lite::GetTimeUs(); auto cost = time_end - time_start; @@ -200,11 +394,13 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { std::cout << output_data[i] << " ,"; } std::cout << std::endl; - std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin"; + std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_xgrad_1_28_28_3.bin"; lite::CompareOutput(output_data, output_path); - // delete input_data; + delete[] in_data; delete pooling_param; + delete[] dy_data; + delete[] dx_data; delete[] output_data; MS_LOG(INFO) << "TestMaxPoolingGradFp32 passed"; } @@ -326,4 +522,216 @@ TEST_F(TestPoolingGradFp32, MaxPoolingKernelGradFp32) { MS_LOG(INFO) << "TestMaxPoolingKernelGradFp32 passed"; } #endif // if 0 before MaxPoolingKernelGradFp32 + +TEST_F(TestPoolingGradFp32, MaxPoolGradBatchFp32) { + // prepare stage + // input size will be equal to the original size of x, output size will be the output size as in forward + auto maxpool = new PoolingParameter(); + InitPoolingParamFP32(maxpool); + maxpool->output_channel_ = 3; + maxpool->pool_mode_ = PoolMode_MaxPool; + maxpool->input_batch_ = 3; + maxpool->output_batch_ = 3; + + size_t input_size; + size_t y_data_size = maxpool->output_batch_ * maxpool->output_channel_ * maxpool->input_h_ * maxpool->input_w_; + + auto x_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_1_x_3_28_28_3.bin", &input_size)); + std::vector dim_x({3, 28, 28, 3}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(x_data); + + auto y_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_1_dx_3_28_28_3.bin", &input_size)); + std::vector dim_y({3, 28, 28, 3}); + lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); + y_tensor.SetData(y_data); + + auto yt_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_1_dy_3_28_28_3.bin", &input_size)); + lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, dim_y); + yt_tensor.SetData(yt_data); + + auto out_data = new float[y_data_size]; + lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); + out_tensor.SetData(out_data); + + std::vector maxpool_inputs = {&x_tensor, &y_tensor, &yt_tensor}; + std::vector maxpool_outputs = {&out_tensor}; + // ---------------------------------------- + kernel::KernelKey maxpool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; + auto maxpool_creator = lite::KernelRegistry::GetInstance()->GetCreator(maxpool_desc); + auto kernel = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast(maxpool), NULL, + maxpool_desc, nullptr); + + kernel->Init(); + + auto time_start = mindspore::lite::GetTimeUs(); + kernel->Run(); + auto time_end = mindspore::lite::GetTimeUs(); + printf("single thread running time : %ld ms\n", time_end - time_start); + + std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_xgrad_3_28_28_3.bin"; + auto res = lite::CompareRelativeOutput(out_data, output_path); + + EXPECT_EQ(res, 0); + + delete[] x_data; + delete[] y_data; + delete[] yt_data; + // delete[] out_data; + // delete conv_param; + x_tensor.SetData(nullptr); + y_tensor.SetData(nullptr); + yt_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); + delete kernel; + MS_LOG(INFO) << "MaxPoolGradBatchFp32 Filter Grad passed"; +} + +TEST_F(TestPoolingGradFp32, MaxPoolGradStride2Fp32) { + // prepare stage + // input size will be equal to the original size of x, output size will be the output size as in forward + auto maxpool = new PoolingParameter(); + InitPoolingParamFP32(maxpool); + maxpool->output_channel_ = 3; + maxpool->input_channel_ = 3; + maxpool->pool_mode_ = PoolMode_MaxPool; + maxpool->input_batch_ = 3; + maxpool->output_batch_ = 3; + maxpool->output_h_ = 14; + maxpool->output_w_ = 14; + maxpool->stride_h_ = 2; + maxpool->stride_w_ = 2; + + size_t input_size; + size_t y_data_size = maxpool->output_batch_ * maxpool->output_channel_ * maxpool->input_h_ * maxpool->input_w_; + + auto x_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_s2_x_3_28_28_3.bin", &input_size)); + std::vector dim_x({maxpool->output_batch_, maxpool->input_h_, maxpool->input_w_, maxpool->input_channel_}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(x_data); + + auto y_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_s2_dx_3_28_28_3.bin", &input_size)); + std::vector dim_y({maxpool->output_batch_, maxpool->output_h_, maxpool->output_w_, maxpool->output_channel_}); + lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); + y_tensor.SetData(y_data); + + auto yt_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_s2_dy_3_28_28_3.bin", &input_size)); + lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, dim_y); + yt_tensor.SetData(yt_data); + + auto out_data = new float[y_data_size]; + lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); + out_tensor.SetData(out_data); + + std::vector maxpool_inputs = {&x_tensor, &y_tensor, &yt_tensor}; + std::vector maxpool_outputs = {&out_tensor}; + // ---------------------------------------- + kernel::KernelKey maxpool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; + auto maxpool_creator = lite::KernelRegistry::GetInstance()->GetCreator(maxpool_desc); + auto kernel = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast(maxpool), NULL, + maxpool_desc, nullptr); + + kernel->Init(); + + auto time_start = mindspore::lite::GetTimeUs(); + kernel->Run(); + auto time_end = mindspore::lite::GetTimeUs(); + printf("single thread running time : %ld ms\n", time_end - time_start); + + std::string output_path = "./test_data/pooling/maxpoolgradfp32_s2_xgrad_3_28_28_3.bin"; + auto res = lite::CompareRelativeOutput(out_data, output_path); + + EXPECT_EQ(res, 0); + + delete[] x_data; + delete[] y_data; + delete[] yt_data; + // delete[] out_data; + // delete conv_param; + x_tensor.SetData(nullptr); + y_tensor.SetData(nullptr); + yt_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); + delete kernel; + MS_LOG(INFO) << "MaxPoolGradStride2Fp32 Filter Grad passed"; +} + +TEST_F(TestPoolingGradFp32, MaxPoolGradStride3Fp32) { + // prepare stage + // input size will be equal to the original size of x, output size will be the output size as in forward + auto maxpool = new PoolingParameter(); + InitPoolingParamFP32(maxpool); + maxpool->output_channel_ = 3; + maxpool->input_channel_ = 3; + maxpool->pool_mode_ = PoolMode_MaxPool; + maxpool->input_batch_ = 3; + maxpool->output_batch_ = 3; + maxpool->output_h_ = 10; + maxpool->output_w_ = 10; + maxpool->stride_h_ = 3; + maxpool->stride_w_ = 3; + + size_t input_size; + size_t y_data_size = maxpool->output_batch_ * maxpool->output_channel_ * maxpool->input_h_ * maxpool->input_w_; + + auto x_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_s3_x_3_28_28_3.bin", &input_size)); + std::vector dim_x({maxpool->output_batch_, maxpool->input_h_, maxpool->input_w_, maxpool->input_channel_}); + lite::tensor::Tensor x_tensor(TypeId::kNumberTypeFloat32, dim_x); + x_tensor.SetData(x_data); + + auto y_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_s3_dx_3_28_28_3.bin", &input_size)); + std::vector dim_y({maxpool->output_batch_, maxpool->output_h_, maxpool->output_w_, maxpool->output_channel_}); + lite::tensor::Tensor y_tensor(TypeId::kNumberTypeFloat32, dim_y); + y_tensor.SetData(y_data); + + auto yt_data = reinterpret_cast( + mindspore::lite::ReadFile("./test_data/pooling/maxpoolgradfp32_s3_dy_3_28_28_3.bin", &input_size)); + lite::tensor::Tensor yt_tensor(TypeId::kNumberTypeFloat32, dim_y); + yt_tensor.SetData(yt_data); + + auto out_data = new float[y_data_size]; + lite::tensor::Tensor out_tensor(TypeId::kNumberTypeFloat32, dim_x); + out_tensor.SetData(out_data); + + std::vector maxpool_inputs = {&x_tensor, &y_tensor, &yt_tensor}; + std::vector maxpool_outputs = {&out_tensor}; + // ---------------------------------------- + kernel::KernelKey maxpool_desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; + auto maxpool_creator = lite::KernelRegistry::GetInstance()->GetCreator(maxpool_desc); + auto kernel = maxpool_creator(maxpool_inputs, maxpool_outputs, reinterpret_cast(maxpool), NULL, + maxpool_desc, nullptr); + + kernel->Init(); + + auto time_start = mindspore::lite::GetTimeUs(); + kernel->Run(); + auto time_end = mindspore::lite::GetTimeUs(); + printf("single thread running time : %ld ms\n", time_end - time_start); + + std::string output_path = "./test_data/pooling/maxpoolgradfp32_s3_xgrad_3_28_28_3.bin"; + auto res = lite::CompareRelativeOutput(out_data, output_path); + + EXPECT_EQ(res, 0); + + delete[] x_data; + delete[] y_data; + delete[] yt_data; + // delete[] out_data; + // delete conv_param; + x_tensor.SetData(nullptr); + y_tensor.SetData(nullptr); + yt_tensor.SetData(nullptr); + out_tensor.SetData(nullptr); + delete kernel; + MS_LOG(INFO) << "MaxPoolGradStride3Fp32 Filter Grad passed"; +} + } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc index eee16499fe..d3bc737393 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc @@ -40,7 +40,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { y_tensor.SetData(input_data); std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin"; - auto ll_labels = reinterpret_cast(mindspore::lite::ReadFile(label_path.c_str(), &input_size)); + auto ll_labels = reinterpret_cast(mindspore::lite::ReadFile(label_path.c_str(), &input_size)); auto labels = new int[6]; for (int i = 0; i < 6; i++) labels[i] = static_cast(ll_labels[i]); @@ -57,7 +57,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { auto grad = new float[24]; lite::tensor::Tensor grad_tensor(TypeId::kNumberTypeFloat32, dim_y); grad_tensor.SetData(grad); - std::vector outputs = {&grad_tensor, &loss_tensor}; + std::vector outputs = {&loss_tensor, &grad_tensor}; kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SoftmaxCrossEntropy}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/dy_2_4_5_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/dy_2_4_5_3.bin new file mode 100644 index 0000000000..2ccfc68d73 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/dy_2_4_5_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/input_x_2_4_5_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/input_x_2_4_5_3.bin new file mode 100644 index 0000000000..a194c59c0a --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/input_x_2_4_5_3.bin @@ -0,0 +1,2 @@ +V_?Kϧ࿅>J?/="m?Luj@!U$?f=?e[?Wھ m ? eO?}4?B?7E :?JͿ̬> ? ~?ϫN1?> HV|ʾ={IU?xvW>[$?]4Bu 4@+?z>uB?=|e >M>>?}0?> @=">: @<>+R +b6.?i?v?`j6R~]?JU6sG?M% ?h>ȿ G½?>ӓ'6?@2/VK5T>X]?[?v_ؿj?p?\l?.l=b? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/output_dbias_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/output_dbias_3.bin new file mode 100644 index 0000000000..ac2915f01f Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/output_dbias_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/output_dscale_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/output_dscale_3.bin new file mode 100644 index 0000000000..cf8cfa4b05 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/output_dscale_3.bin @@ -0,0 +1 @@ +BBB \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/save_mean_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/save_mean_3.bin new file mode 100644 index 0000000000..a3cd501da7 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/save_mean_3.bin @@ -0,0 +1 @@ +:c2;@e< \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/save_var_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/save_var_3.bin new file mode 100644 index 0000000000..aa2b0f6ede --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/save_var_3.bin @@ -0,0 +1 @@ +=}?M?Z|? \ No newline at end of file diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/scale_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/scale_3.bin new file mode 100644 index 0000000000..c146119e96 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/bngrad/scale_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_output1_nhwc.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_output1_nhwc.bin index 66813ca607..6c1da39c17 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_output1_nhwc.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/conv1x1fp32_output1_nhwc.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_32_3_3_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_32_3_3_3.bin index e69de29bb2..230d3a0367 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_32_3_3_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_32_3_3_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_18_3_3_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_18_3_3_3.bin index e69de29bb2..1565337c09 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_18_3_3_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_18_3_3_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin index e69de29bb2..71f3b26608 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dw_g3_d2_18_3_3_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_1_28_28_3.bin index e69de29bb2..f7dd08a76d 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_1_28_28_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_1_28_28_3.bin index e69de29bb2..560b8f1283 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_1_28_28_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin index e69de29bb2..299fd68961 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dx_g3_d2_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_1_28_28_32.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_1_28_28_32.bin index e69de29bb2..bb6772de40 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_1_28_28_32.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_1_28_28_32.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_1_28_28_18.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_1_28_28_18.bin index e69de29bb2..d4fde111be 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_1_28_28_18.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_1_28_28_18.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin index e69de29bb2..d8483212b8 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_dy_g3_d2_1_26_26_18.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_32_3_3_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_32_3_3_3.bin index e69de29bb2..62e3cf04e3 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_32_3_3_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_32_3_3_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin index e69de29bb2..b2cc35ffd3 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_w_g3_d2_18_3_3_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_1_28_28_3.bin index e69de29bb2..52548830ed 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_1_28_28_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_1_28_28_3.bin index e69de29bb2..e9cb65bae4 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_1_28_28_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin index e69de29bb2..62aba80e0a 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_x_g3_d2_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin index e69de29bb2..8d34666024 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/conv/convfp32_y_g3_d2_1_26_26_18.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_input_x_1_3_224_224.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_input_x_1_3_224_224.bin new file mode 100755 index 0000000000..f578efc542 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_input_x_1_3_224_224.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_f_1_1280_7_7.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_f_1_1280_7_7.bin new file mode 100755 index 0000000000..faf89d0cd2 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_f_1_1280_7_7.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_y_1_1000.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_y_1_1000.bin new file mode 100755 index 0000000000..39130e938a Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effNet_output_y_1_1000.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.ms b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.ms new file mode 100644 index 0000000000..67f49ee814 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.ms differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.pb.fbs.ms b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.pb.fbs.ms new file mode 100644 index 0000000000..e5a0f0cbe6 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.pb.fbs.ms differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dx_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dx_3_28_28_3.bin new file mode 100644 index 0000000000..03e21edf31 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dx_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dy_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dy_3_28_28_3.bin new file mode 100644 index 0000000000..376f38816f Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_dy_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_x_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_x_3_28_28_3.bin new file mode 100644 index 0000000000..ff6b4c92ff Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_1_x_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s2_dx_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s2_dx_3_28_28_3.bin new file mode 100644 index 0000000000..331b4c0b1f Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s2_dx_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s2_dy_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s2_dy_3_28_28_3.bin new file mode 100644 index 0000000000..49737f4a3b Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s2_dy_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s2_x_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s2_x_3_28_28_3.bin new file mode 100644 index 0000000000..7529679516 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s2_x_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s3_dx_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s3_dx_3_28_28_3.bin new file mode 100644 index 0000000000..2b70e7ce70 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s3_dx_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s3_dy_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s3_dy_3_28_28_3.bin new file mode 100644 index 0000000000..2917f02ca1 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s3_dy_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s3_x_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s3_x_3_28_28_3.bin new file mode 100644 index 0000000000..7529679516 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/avgpoolgradfp32_s3_x_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin index cca67a85df..f9d7f86a28 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_3_28_28_3.bin new file mode 100644 index 0000000000..768d6b1872 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dx_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin index 15c810365e..ed6176fa92 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_3_28_28_3.bin new file mode 100644 index 0000000000..fbb2f6fb3f Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_dy_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_x_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_x_1_28_28_3.bin new file mode 100644 index 0000000000..44557eee44 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_x_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_x_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_x_3_28_28_3.bin new file mode 100644 index 0000000000..67e872b6d3 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_x_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_xgrad_1_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_xgrad_1_28_28_3.bin new file mode 100644 index 0000000000..750df112b2 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_xgrad_1_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_xgrad_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_xgrad_3_28_28_3.bin new file mode 100644 index 0000000000..d4970bf50b Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_1_xgrad_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_dx_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_dx_3_28_28_3.bin new file mode 100644 index 0000000000..79e21434c5 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_dx_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_dy_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_dy_3_28_28_3.bin new file mode 100644 index 0000000000..f4b1092b71 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_dy_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_x_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_x_3_28_28_3.bin new file mode 100644 index 0000000000..37fa25c5fb Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_x_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_xgrad_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_xgrad_3_28_28_3.bin new file mode 100644 index 0000000000..68891f67ce Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s2_xgrad_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_dx_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_dx_3_28_28_3.bin new file mode 100644 index 0000000000..ee10edd675 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_dx_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_dy_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_dy_3_28_28_3.bin new file mode 100644 index 0000000000..ffa88601f8 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_dy_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_x_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_x_3_28_28_3.bin new file mode 100644 index 0000000000..8b001727eb Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_x_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_xgrad_3_28_28_3.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_xgrad_3_28_28_3.bin new file mode 100644 index 0000000000..2491047492 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/pooling/maxpoolgradfp32_s3_xgrad_3_28_28_3.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_input_32_1000.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_input_32_1000.bin index e69de29bb2..27b1155793 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_input_32_1000.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_input_32_1000.bin differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_weight_10_1000.bin b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_weight_10_1000.bin index e69de29bb2..1bbac89189 100644 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_weight_10_1000.bin and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/train/train_weight_10_1000.bin differ diff --git a/mindspore/lite/tools/anf_importer/CMakeLists.txt b/mindspore/lite/tools/anf_importer/CMakeLists.txt index 040ebda12d..92488d226f 100644 --- a/mindspore/lite/tools/anf_importer/CMakeLists.txt +++ b/mindspore/lite/tools/anf_importer/CMakeLists.txt @@ -1,6 +1,7 @@ -file(GLOB_RECURSE ANF_IMPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} +file(GLOB ANF_IMPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cc ) + add_library(anf_importer_mid OBJECT ${ANF_IMPORTER_SRC_LIST} ) diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 4148217e70..5cf1c9c0ff 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -25,6 +25,10 @@ namespace mindspore { namespace lite { static const std::vector nhwcOpList = { +#ifdef SUPPORT_TRAIN + schema::PrimitiveType_Conv2DGradFilter, schema::PrimitiveType_Conv2DGradInput, + schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_BiasGrad, +#endif schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize, diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 840fde4412..1f9af6fef6 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -34,6 +34,8 @@ std::vector GetInsertOpList(); std::vector GetNhwcOpList(); +std::vector GetNhwcDualInputOpList(); + std::vector Getfp32FullOpList(); std::vector GetUint8NhwcOpList(); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index de1130fc47..0eb17b0518 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -139,7 +139,12 @@ set(LITE_SRC ${SRC_DIR}/executor.cc ${SRC_DIR}/model.cc ) +if (SUPPORT_TRAIN) +set(LITE_SRC + ${LITE_SRC} + ) +endif () set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm) file(GLOB KERNEL_SRC ${ARM_DIR}/base/*.cc