diff --git a/mindspore/lite/include/train_session.h b/mindspore/lite/include/train_session.h index f6012ab852..1a641b9119 100644 --- a/mindspore/lite/include/train_session.h +++ b/mindspore/lite/include/train_session.h @@ -30,14 +30,14 @@ namespace session { class TrainSession : public lite::LiteSession { public: TrainSession(); - ~TrainSession() = default; + ~TrainSession(); int RunGraph(const session::KernelCallBack &before = nullptr, const session::KernelCallBack &after = nullptr) override; int CompileGraph(lite::Model *model) override; virtual void ReplaceOps(); - virtual void *ExportToBuf(void *buf, size_t *len) const; + virtual void* ExportToBuf(lite::Model *model, void* buf, size_t* len) const; // todo: output tensors by tensor name std::unordered_map> GetOutputMap() const; diff --git a/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc index 375bc976ad..2acc33ee2d 100644 --- a/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc +++ b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc @@ -31,9 +31,9 @@ int DoArithmeticSelfGrad(const TensorPtrVector &in_tensors, const TensorPtrVecto mindspore::lite::Allocator *allocator) { size_t data_size = in_tensors[0]->ElementsNum(); OpParameter *param = node->primitive_; - float *dy_data = (float *)in_tensors[0]->data_; - float *x_data = (float *)in_tensors[1]->data_; - float *dx_data = (float *)(float *)out_tensors[0]->data_; + float *dy_data = reinterpret_cast(in_tensors[0]->data_); + float *x_data = reinterpret_cast(in_tensors[1]->data_); + float *dx_data = reinterpret_cast(out_tensors[0]->data_); int ret; if (param->type_ == KernelType::LogGrad) { ret = ElementDiv(dy_data, x_data, dx_data, data_size); diff --git a/mindspore/lite/nnacl/fp32_grad/batch_norm.c b/mindspore/lite/nnacl/fp32_grad/batch_norm.c index 5f511e15c7..bee4bf433d 100644 --- a/mindspore/lite/nnacl/fp32_grad/batch_norm.c +++ b/mindspore/lite/nnacl/fp32_grad/batch_norm.c @@ -28,7 +28,7 @@ void sumSpatialBatch(const float *in, int size, int ch, float *out) { } static void meanVar(const float *in, int size, int ch, float eps, float *mean, float *invar) { - float N = (float)size; + float N = (float)(size); sumSpatialBatch(in, N, ch, mean); for (int f = 0; f < ch; ++f) { mean[f] /= N; diff --git a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c index 58423c192c..87d55504df 100644 --- a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c @@ -31,63 +31,29 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter int output_h = pooling_param->output_h_; int output_batch = pooling_param->output_batch_; - const float *inPtr = NULL; - // for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0; for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0; float kk = (float)(win_h * win_w); - for (uint16_t ib = 0; ib < output_batch; ib++) { - float *out; - // out = &output_ptr[(ib * output_h * output_w)]; - out = &output_ptr[(ib * in_h * in_w * channel)]; - // inPtr = (float *)(&input_ptr[(ib * in_h * in_w)]); - inPtr = (float *)(&input_ptr[(ib * output_h * output_w * channel)]); - if (1) { // in->layout() == Tensor::nhwc) - // iterate over yt - for (uint16_t yh = 0; yh < output_h; yh++) { - for (uint16_t yw = 0; yw < output_w; yw++) { - for (uint16_t ic = 0; ic < channel; ic++) { - int idx = (yw 
+ yh * output_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw; - float delta = inPtr[idx] / kk; - for (int32_t kh = 0; kh < win_h; kh++) { - int xh = yh * stride_h + kh - pad_h; - if ((xh < 0) || (xh >= in_h)) { - continue; - } - for (int32_t kw = 0; kw < win_w; kw++) { - int xw = yw * stride_w + kw - pad_w; - if ((xw < 0) || (xw >= in_w)) { - continue; - } - - // out[(xw + output_w * xh) * channel + ic] += delta; - out[(xw + in_w * xh) * channel + ic] += delta; - } + float *out = &output_ptr[(ib * in_h * in_w * channel)]; + const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)]; + // iterate over yt + for (uint16_t yh = 0; yh < output_h; yh++) { + for (uint16_t yw = 0; yw < output_w; yw++) { + for (uint16_t ic = 0; ic < channel; ic++) { + int idx = (yw + yh * output_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw; + float delta = inPtr[idx] / kk; + for (int32_t kh = 0; kh < win_h; kh++) { + int xh = yh * stride_h + kh - pad_h; + if ((xh < 0) || (xh >= in_h)) { + continue; } - } - } - } - } else { // nchw - for (uint16_t ic = 0; ic < channel; ic++) { - // iterate over yt - for (uint16_t yh = 0; yh < output_h; yh++) { - for (uint16_t yw = 0; yw < output_w; yw++) { - int idx = (ic * output_h * output_w) + (output_w * yh) + yw; - float delta = inPtr[idx] / kk; - for (int32_t kh = 0; kh < win_h; kh++) { - int xh = yh * stride_h + kh - pad_h; - if ((xh < 0) || (xh >= in_h)) { + for (int32_t kw = 0; kw < win_w; kw++) { + int xw = yw * stride_w + kw - pad_w; + if ((xw < 0) || (xw >= in_w)) { continue; } - for (int32_t kw = 0; kw < win_w; kw++) { - int xw = yw * stride_w + kw - pad_w; - if ((xw < 0) || (xw >= in_w)) { - continue; - } - // out[(ic * output_h * output_w) + (xh * output_w) + xw] += delta; - out[(ic * in_h * in_w) + (xh * in_w) + xw] += delta; - } + out[(xw + in_w * xh) * channel + ic] += delta; } } } @@ -111,73 +77,39 @@ void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy int output_h = pooling_param->output_h_; int output_batch = pooling_param->output_batch_; - const float *inPtr; - const float *dyPtr; - for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0; for (uint16_t ib = 0; ib < output_batch; ib++) { - float *out; - out = &output_ptr[(ib * in_h * in_w * channel)]; - inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]); - dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]); + float *out = &output_ptr[(ib * in_h * in_w * channel)]; + const float *inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]); + const float *dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]); - if (1) { // nhwc - for (uint16_t yh = 0; yh < output_h; yh++) { - for (uint16_t yw = 0; yw < output_w; yw++) { - for (uint16_t ic = 0; ic < channel; ic++) { - int idx = (yw + yh * output_w) * channel + ic; + for (uint16_t yh = 0; yh < output_h; yh++) { + for (uint16_t yw = 0; yw < output_w; yw++) { + for (uint16_t ic = 0; ic < channel; ic++) { + int idx = (yw + yh * output_w) * channel + ic; - float delta = dyPtr[idx]; - float max_val = -FLT_MAX; - int max_idx = 0; - for (int32_t kh = 0; kh < win_h; kh++) { - int xh = yh * stride_h + kh - pad_h; - if ((xh < 0) || (xh >= in_h)) { - continue; - } - for (int32_t kw = 0; kw < win_w; kw++) { - int xw = yw * stride_w + kw - pad_w; - if ((xw < 0) || (xw >= in_w)) { - continue; - } - - if (inPtr[(xw + in_w * xh) * channel + ic] > max_val) { - max_val = inPtr[(xw + in_w * xh) * channel + ic]; - max_idx = (xw + in_w * xh) * 
channel + ic; - } - } + float delta = dyPtr[idx]; + float max_val = -FLT_MAX; + int max_idx = 0; + for (int32_t kh = 0; kh < win_h; kh++) { + int xh = yh * stride_h + kh - pad_h; + if ((xh < 0) || (xh >= in_h)) { + continue; } - out[max_idx] += delta; - } - } - } - } else { // nchw - for (uint16_t yh = 0; yh < output_h; yh++) { - for (uint16_t yw = 0; yw < output_w; yw++) { - for (uint16_t ic = 0; ic < channel; ic++) { - int idx = (ic * output_h * output_w) + (output_w * yh) + yw; - float delta = dyPtr[idx]; - float max_val = -FLT_MAX; - int max_idx = 0; - for (int32_t kh = 0; kh < win_h; kh++) { - int xh = yh * stride_h + kh - pad_h; - if ((xh < 0) || (xh >= in_h)) { + for (int32_t kw = 0; kw < win_w; kw++) { + int xw = yw * stride_w + kw - pad_w; + if ((xw < 0) || (xw >= in_w)) { continue; } - for (int32_t kw = 0; kw < win_w; kw++) { - int xw = yw * stride_w + kw - pad_w; - if ((xw < 0) || (xw >= in_w)) { - continue; - } - if (inPtr[(ic * in_h * in_w) + (xh * in_w) + xw] > max_val) { - max_val = inPtr[(ic * in_h * in_w) + (xh * in_w) + xw]; - max_idx = (ic * in_h * in_w) + (xh * in_w) + xw; - } + + if (inPtr[(xw + in_w * xh) * channel + ic] > max_val) { + max_val = inPtr[(xw + in_w * xh) * channel + ic]; + max_idx = (xw + in_w * xh) * channel + ic; } } - out[max_idx] += delta; } + out[max_idx] += delta; } } } diff --git a/mindspore/lite/src/common/file_utils_ext.cc b/mindspore/lite/src/common/file_utils_ext.cc index 39e9983ced..689e09421a 100644 --- a/mindspore/lite/src/common/file_utils_ext.cc +++ b/mindspore/lite/src/common/file_utils_ext.cc @@ -42,11 +42,9 @@ int CompareRelativeOutput(float *output_data, std::string file_path) { size_t output_size; auto ground_truth = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); size_t output_num = output_size / sizeof(float); - // std::cout << "output num : " << output_num << "\n"; int error = CompareOutputRelativeData(output_data, ground_truth, output_num); delete [] ground_truth; if (error > 1e-4) { - std::cout << "has accuracy error!\n" << error << "\n"; return 1; } return 0; @@ -56,7 +54,6 @@ float RelativeOutputError(float *output_data, std::string file_path) { size_t output_size; auto ground_truth = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); size_t output_num = output_size / sizeof(float); - std::cout << "output num : " << output_num << "\n"; float error = CompareOutputRelativeData(output_data, ground_truth, output_num); delete [] ground_truth; return error; diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc index 330ab443ea..fcd53a46a1 100644 --- a/mindspore/lite/src/executor.cc +++ b/mindspore/lite/src/executor.cc @@ -51,8 +51,6 @@ int Executor::Run(std::vector &in_tensors, std::vector &out_ MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->name(); } } - // JBDEBUG - // std::cout << "executing kernel " << kernel->name() << "\n"; auto ret = kernel->Run(); if (0 != ret) { MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc index 41e45c0d82..7ade3be4fb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc @@ -68,19 +68,12 @@ int ApplyMomentumCPUKernel::Init() { for (size_t i = 0; i < elem_num; i++) accumulate[i] = 0.0; workspace = new float[elem_num]; - return 0; -} -#if 0 
-OpParameter *PopulateApplyMomentumParameter(const lite::Primitive *primitive) { - OpParameter *param = new (std::nothrow) OpParameter(); - if (param == nullptr) { - MS_LOG(ERROR) << "new Param for OptMomentum failed."; - return nullptr; + if (workspace == nullptr) { + MS_LOG(ERROR) << "apply momentum workspace fail to malloc!"; + return RET_ERROR; } - param->type_ = primitive->Type(); - return param; + return 0; } -#endif kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h index f688fb3fcb..121248ed3c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h @@ -27,8 +27,11 @@ class ApplyMomentumCPUKernel : public LiteKernel { explicit ApplyMomentumCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~ApplyMomentumCPUKernel() override { delete[] workspace; } + : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr) {} + ~ApplyMomentumCPUKernel() override { + if (workspace) + delete[] workspace; + } int Init() override; int ReSize() override; @@ -38,8 +41,6 @@ class ApplyMomentumCPUKernel : public LiteKernel { float *workspace; }; -// OpParameter *PopulateApplyMomentumParameter(const lite::Primitive *primitive); - } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_APPLY_MOMENTUM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc index d23304f09e..1c34c40e18 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_grad.cc @@ -58,7 +58,6 @@ int ArithmeticGradCPUKernel::Init() { tile_data1 = new (std::nothrow) float[in_tensors_.at(0)->ElementsNum()]; if (tile_data1 == nullptr) { MS_LOG(ERROR) << "new data1 fail!"; - delete tile_data0; return RET_ERROR; } @@ -66,8 +65,6 @@ int ArithmeticGradCPUKernel::Init() { tile_data2 = new (std::nothrow) float[in_tensors_.at(0)->ElementsNum()]; if (tile_data2 == nullptr) { MS_LOG(ERROR) << "new data2 fail!"; - delete tile_data0; - delete tile_data1; return RET_ERROR; } } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc index cbb78c3d32..fcbfd02bd7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc @@ -29,30 +29,8 @@ using mindspore::lite::RET_OK; // using mindspore::lite::REG_OP; using mindspore::schema::PrimitiveType_BNGrad; -/* -{dy} -{x } -{scale } -{save_mean } -{save_inv_variance } -*/ -namespace mindspore::kernel { - -#if 0 -OpParameter *PopulateBNGradParameter(const lite::Primitive *primitive) { - BNGradParameter *param = new (std::nothrow) BNGradParameter(); - if (param == nullptr) { - MS_LOG(ERROR) << "new Param for conv grad filter failed."; - return nullptr; - } - param->op_parameter_.type_ = primitive->Type(); - auto bngrad_primitive = primitive->Value()->value_as_BNGrad(); - param->epsilon_ = bngrad_primitive->eps(); - param->momentum_ = bngrad_primitive->momentum(); - return 
reinterpret_cast(param); -} -#endif +namespace mindspore::kernel { int BNGradCPUKernel::Init() { auto *input_x = in_tensors_.at(1); int channels = input_x->shape().at(kNHWC_C); @@ -68,7 +46,6 @@ int BNGradCPUKernel::Init() { int BNGradCPUKernel::ReSize() { return RET_OK; } int BNGradCPUKernel::Run() { - // std::cout << "run succ" << std::endl; auto prepare_ret = Prepare(); if (prepare_ret != RET_OK) { MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h index 827e4b88e1..e5055caa61 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.h @@ -28,8 +28,12 @@ class BNGradCPUKernel : public LiteKernel { explicit BNGradCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~BNGradCPUKernel() override { delete [] workspace; } + : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr), + workspace_size(0) {} + ~BNGradCPUKernel() override { + if (workspace) + delete [] workspace; + } int Init() override; int ReSize() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc index f84e0f3ab9..d4b1a89609 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc @@ -47,6 +47,10 @@ int ConvolutionTrainCPUKernel::Init() { conv_param_->input_channel_ / conv_param_->group_; workspace = new (std::nothrow) float[ws_size]; + if (workspace == nullptr) { + MS_LOG(ERROR) << "new workspace fail!"; + return RET_ERROR; + } return RET_OK; } @@ -95,8 +99,6 @@ int ConvolutionTrainCPUKernel::Run() { gemm(0, 1, m, n, k, 1, mat_a, k, mat_b, k, 1, mat_c, out_ch); } } - - // std::cout << "run succ" << std::endl; return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h index f5947b324f..7cd9cbaf54 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h @@ -27,8 +27,11 @@ class ConvolutionTrainCPUKernel : public LiteKernel { explicit ConvolutionTrainCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~ConvolutionTrainCPUKernel() override { delete[] workspace; } + : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr) {} + ~ConvolutionTrainCPUKernel() override { + if (workspace) + delete[] workspace; + } int Init() override; int ReSize() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc index 3134deec3c..4b82b7a814 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc @@ -37,10 +37,7 @@ int ConvolutionGradFilterCPUKernel::Init() { MS_ASSERT(x_tensor != nullptr); auto *dy_tensor = in_tensors_.at(0); MS_ASSERT(dy_tensor != nullptr); -#if 0 - auto 
*weight_tensor = out_tensors_.at(0); - MS_ASSERT(weight_tensor != nullptr); -#endif + auto conv_param = reinterpret_cast(op_parameter_); conv_param->output_batch_ = dy_tensor->shape().at(kNHWC_N); conv_param->input_batch_ = x_tensor->shape().at(kNHWC_N); @@ -49,7 +46,7 @@ int ConvolutionGradFilterCPUKernel::Init() { // assume OutCh|kh|kw|InCh conv_param->input_channel_ = x_tensor->shape().at(kNHWC_C); conv_param->output_channel_ = dy_tensor->shape().at(kNHWC_C); - // TBD + conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; @@ -113,52 +110,9 @@ int ConvolutionGradFilterCPUKernel::Run() { gemm(1, 1, k, n, m, 1, mat_a, out_ch, mat_b, m, 1, mat_c, n); } } - - // std::cout << "run succ" << std::endl; return RET_OK; } -#if 0 -OpParameter *PopulateConvolutionGradFilterParameter(const lite::Primitive *primitive) { - ConvParameter *param = new (std::nothrow) ConvParameter(); - if (param == nullptr) { - MS_LOG(ERROR) << "new Param for conv grad filter failed."; - return nullptr; - } - param->op_parameter_.type_ = primitive->Type(); - - auto convg_primitive = primitive->Value()->value_as_Conv2DGradFilter(); - param->kernel_h_ = convg_primitive->kernelH(); - param->kernel_w_ = convg_primitive->kernelW(); - param->stride_h_ = convg_primitive->strideH(); - param->stride_w_ = convg_primitive->strideW(); - param->dilation_h_ = convg_primitive->dilateH(); - param->dilation_w_ = convg_primitive->dilateW(); - param->pad_h_ = convg_primitive->padUp(); - param->pad_w_ = convg_primitive->padLeft(); - param->pad_u_ = convg_primitive->padUp(); - param->pad_d_ = convg_primitive->padDown(); - param->pad_l_ = convg_primitive->padLeft(); - param->pad_r_ = convg_primitive->padRight(); - param->group_ = convg_primitive->group(); - auto act_type = convg_primitive->activationType(); - switch (act_type) { - case schema::ActivationType_RELU: - param->is_relu_ = true; - param->is_relu6_ = false; - break; - case schema::ActivationType_RELU6: - param->is_relu_ = false; - param->is_relu6_ = true; - break; - default: - param->is_relu_ = false; - param->is_relu6_ = false; - break; - } - return reinterpret_cast(param); -} -#endif kernel::LiteKernel *CpuConvGradFilterFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h index 5d1da23efc..ea2f4aa825 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h @@ -27,8 +27,11 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel { explicit ConvolutionGradFilterCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~ConvolutionGradFilterCPUKernel() override { delete[] workspace; } + : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr) {} + ~ConvolutionGradFilterCPUKernel() override { + if (workspace) + delete[] workspace; + } int Init() override; int ReSize() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc index d31a14bcbb..0320bbf430 100644 --- 
a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc @@ -47,7 +47,6 @@ int ConvolutionGradInputCPUKernel::Init() { conv_param->input_channel_ = dx_tensor->shape()[(kNHWC_C)]; conv_param->output_channel_ = weight_tensor->shape()[(kNHWC_N)]; - // TBD conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; @@ -59,7 +58,7 @@ int ConvolutionGradInputCPUKernel::Init() { MS_LOG(ERROR) << "new workspace fail!"; return RET_ERROR; } - return 0; + return RET_OK; } int ConvolutionGradInputCPUKernel::ReSize() { return 0; } @@ -108,53 +107,8 @@ int ConvolutionGradInputCPUKernel::Run() { col2im_hwc(mat_c, dx_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups), conv_param); } } - - // std::cout << "run succ" << std::endl; - return 0; -} - -#if 0 -OpParameter *PopulateConvolutionGradInputParameter(const lite::Primitive *primitive) { - ConvParameter *param = new (std::nothrow) ConvParameter(); - if (param == nullptr) { - MS_LOG(ERROR) << "new Param for conv grad input failed."; - return nullptr; - } - param->op_parameter_.type_ = primitive->Type(); - - auto convg_primitive = primitive->Value()->value_as_Conv2DGradInput(); - param->kernel_h_ = convg_primitive->kernelH(); - param->kernel_w_ = convg_primitive->kernelW(); - param->stride_h_ = convg_primitive->strideH(); - param->stride_w_ = convg_primitive->strideW(); - param->dilation_h_ = convg_primitive->dilateH(); - param->dilation_w_ = convg_primitive->dilateW(); - param->pad_h_ = convg_primitive->padUp(); - param->pad_w_ = convg_primitive->padLeft(); - param->pad_u_ = convg_primitive->padUp(); - param->pad_d_ = convg_primitive->padDown(); - param->pad_l_ = convg_primitive->padLeft(); - param->pad_r_ = convg_primitive->padRight(); - param->group_ = convg_primitive->group(); - auto act_type = convg_primitive->activationType(); - switch (act_type) { - case schema::ActivationType_RELU: - param->is_relu_ = true; - param->is_relu6_ = false; - break; - case schema::ActivationType_RELU6: - param->is_relu_ = false; - param->is_relu6_ = true; - break; - default: - param->is_relu_ = false; - param->is_relu6_ = false; - break; - } - - return reinterpret_cast(param); + return RET_OK; } -#endif kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h index 3c091f3e82..0090608f32 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h @@ -27,8 +27,11 @@ class ConvolutionGradInputCPUKernel : public LiteKernel { explicit ConvolutionGradInputCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} - ~ConvolutionGradInputCPUKernel() override { delete[] workspace; } + : LiteKernel(parameter, inputs, outputs, ctx, primitive), workspace(nullptr) {} + ~ConvolutionGradInputCPUKernel() override { + if (workspace) + delete[] workspace; + } int Init() override; int ReSize() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc index 
ccee7786e6..33a1cc872c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc @@ -20,7 +20,6 @@ #include "nnacl/fp32/pooling.h" #include "nnacl/fp32_grad/pooling_grad.h" #include "include/errorcode.h" -// #include "src/train/ops/train_ops.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -51,11 +50,6 @@ int PoolingGradCPUKernel::Init() { pool_param->input_w_ = in_shape[kNHWC_W]; pool_param->input_batch_ = in_shape[kNHWC_N]; pool_param->input_channel_ = in_shape[kNHWC_C]; - - // Emir -- here I assume we get the outputshape in the output tensor - // auto *out_tensor = out_tensors_.front(); - // auto out_shape = in_tensors_.at(1)->shape(); - pool_param->output_h_ = out_shape[kNHWC_H]; pool_param->output_w_ = out_shape[kNHWC_W]; pool_param->output_batch_ = out_shape[kNHWC_N]; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc index f3ee83274c..d608350545 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc @@ -55,53 +55,6 @@ void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *lab } output2[0] = total_loss / param_->batch_size_; } - -#if 0 -void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses, - float *output) const { - float total_loss = 0; - for (int i = 0; i < param_->batch_size_; ++i) { - if (labels[i] < 0) { - MS_LOG(EXCEPTION) << "label value must >= 0"; - } - size_t label = labels[i]; - if (label > param->number_of_classes_) { - MS_LOG(EXCEPTION) << "error label input!"; - } else { - total_loss -= logf(losses[i * param->number_of_classes_ + label]); - } - } - output[0] = total_loss / param->batch_size_; -} - -void SoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, float *grads, - float *output) const { - size_t row_start = 0; - float total_loss = 0; - for (int i = 0; i < param->batch_size_; ++i) { - if (labels[i] < 0) { - MS_LOG(EXCEPTION) << "label value must >= 0"; - } - size_t label = labels[i]; - if (label > param->number_of_classes_) { - MS_LOG(EXCEPTION) << "error label input!"; - } else { - total_loss -= logf(losses[i * param->number_of_classes_ + label]); - for (size_t j = 0; j < param->number_of_classes_; ++j) { - size_t index = row_start + j; - if (j == label) { - grads[index] = (losses[index] - 1) / param->batch_size_; - } else { - grads[index] = losses[index] / param->batch_size_; - } - } - } - row_start += param->number_of_classes_; - } - output[0] = total_loss / param->batch_size_; -} -#endif - int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() { auto ret = Prepare(); if (ret != RET_OK) { @@ -117,11 +70,6 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() { grads = reinterpret_cast(out_tensors_.at(1)->MutableData()); } size_t data_size = in_tensors_.at(0)->ElementsNum(); - float *losses = new (std::nothrow) float[data_size]; - if (losses == nullptr) { - MS_LOG(ERROR) << "losses is null"; - return RET_ERROR; - } MS_ASSERT(out != nullptr); MS_ASSERT(labels != nullptr); @@ -151,9 +99,16 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() { size_t data_size = in_tensors_.at(0)->ElementsNum(); losses_ = new (std::nothrow) float[data_size]; + if (losses_ == 
nullptr) { + MS_LOG(ERROR) << "failed to malloc losses!"; + return RET_ERROR; + } + sum_data_ = new (std::nothrow) float[dims[0]]; - MS_ASSERT(losses_ != nullptr); - MS_ASSERT(sum_data_ != nullptr); + if (sum_data_ == nullptr) { + MS_LOG(ERROR) << "failed to malloc sum_data_!"; + return RET_ERROR; + } sm_params_.n_dim_ = 2; sm_params_.element_size_ = data_size; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.h index 6eaeb5a4d1..123f9d48e7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.h @@ -33,12 +33,14 @@ class SoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) - : LossKernel(parameter, inputs, outputs, ctx, primitive) { + : LossKernel(parameter, inputs, outputs, ctx, primitive), losses_(nullptr), sum_data_(nullptr) { param_ = reinterpret_cast(parameter); } ~SoftmaxCrossEntropyWithLogitsCPUKernel() override { - delete[] losses_; - delete[] sum_data_; + if (losses_) + delete[] losses_; + if (sum_data_) + delete[] sum_data_; } void ForwardPostExecute(const float *labels, const float *logits, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.cc index cb09b077f7..51f4b93272 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.cc @@ -30,8 +30,6 @@ using mindspore::lite::RET_OK; namespace mindspore::kernel { int SoftmaxGradCPUKernel::Init() { - // auto input_tensor =in_tensors_.at(0); - param = reinterpret_cast(op_parameter_); auto in_shape = in_tensors_.at(0)->shape(); auto in_dims = in_shape.size(); @@ -43,7 +41,6 @@ int SoftmaxGradCPUKernel::Init() { } param->element_size_ = ele_size; - // malloc tmp buffer auto axis = param->axis_; if ((axis < -1) || (axis > param->n_dim_)) { MS_LOG(ERROR) << "SoftmaxGrad axis is invalid!"; @@ -57,9 +54,17 @@ int SoftmaxGradCPUKernel::Init() { } sum_data_ = new (std::nothrow) float[inner_size]; - MS_ASSERT(sum_data_ != nullptr); + if (sum_data_ == nullptr) { + MS_LOG(ERROR) << "failed to malloc sum_data_!"; + return RET_ERROR; + } + sum_mul_ = new (std::nothrow) float[inner_size * in_shape[axis]]; - MS_ASSERT(sum_mul_ != nullptr); + if (sum_mul_ == nullptr) { + MS_LOG(ERROR) << "failed to malloc sum_mul_!"; + return RET_ERROR; + } + return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.h index c35271a2f1..b70f1a1bff 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_grad.h @@ -28,11 +28,15 @@ class SoftmaxGradCPUKernel : public LiteKernel { explicit SoftmaxGradCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + : LiteKernel(parameter, inputs, outputs, ctx, primitive), sum_data_(nullptr), sum_mul_(nullptr) { param = reinterpret_cast(parameter); } - ~SoftmaxGradCPUKernel() override = default; - + ~SoftmaxGradCPUKernel() override { + if (sum_data_) + delete[] 
sum_data_; + if (sum_mul_) + delete[] sum_mul_; + } int Init() override; int ReSize() override; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc index b9d38ad9b8..fbbe3f88dd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc @@ -89,12 +89,6 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { grads = reinterpret_cast(out_tensors_.at(1)->MutableData()); } size_t data_size = in_tensors_.at(0)->ElementsNum(); - float *losses = new (std::nothrow) float[data_size]; - if (losses == nullptr) { - MS_LOG(ERROR) << "losses is null"; - return RET_ERROR; - } - MS_ASSERT(out != nullptr); MS_ASSERT(labels != nullptr); MS_ASSERT(ins != nullptr); @@ -128,12 +122,18 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() { MS_LOG(ERROR) << "softmax etropy loss in0 have no data"; return RET_ERROR; } - size_t data_size = in_tensors_.at(0)->ElementsNum(); losses_ = new (std::nothrow) float[data_size]; + if (losses_ == nullptr) { + MS_LOG(ERROR) << "failed to malloc losses!"; + return RET_ERROR; + } + sum_data_ = new (std::nothrow) float[dims[0]]; - MS_ASSERT(losses_ != nullptr); - MS_ASSERT(sum_data_ != nullptr); + if (sum_data_ == nullptr) { + MS_LOG(ERROR) << "failed to malloc sum_data_!"; + return RET_ERROR; + } sm_params_.n_dim_ = 2; sm_params_.element_size_ = data_size; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h index 991329736b..19d12ea8fb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h @@ -33,12 +33,14 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { const std::vector &outputs, const lite::Context *ctx, const mindspore::lite::PrimitiveC *primitive) - : LossKernel(parameter, inputs, outputs, ctx, primitive) { + : LossKernel(parameter, inputs, outputs, ctx, primitive) , losses_(nullptr), sum_data_(nullptr) { param = reinterpret_cast(parameter); } ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override { - delete[] losses_; - delete[] sum_data_; + if (losses_) + delete[] losses_; + if (sum_data_) + delete[] sum_data_; } void ForwardPostExecute(const int *labels, const float *losses, float *output) const; diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 315042502a..e3f2ca7101 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -47,9 +47,19 @@ int TrainSession::CompileGraph(lite::Model *model) { return LiteSession::CompileGraph(model); } -void *TrainSession::ExportToBuf(void *buf, size_t *len) const { - // auto train_model_impl = (dynamic_cast(model_->model_impl())); - // return train_model_impl->ExportToBuf(buf, len); +TrainSession::~TrainSession() { + for (auto it1 = ext_output_map_.begin(); it1 != ext_output_map_.end(); ++it1) { + if ((output_node_map_.find(it1->first) == output_node_map_.end()) || train_mode_) { + // Delete if not from output_node_map_ + auto tensor_ptr = it1->second.back(); + delete tensor_ptr; + it1->second.pop_back(); + 
} + } +} + +void *TrainSession::ExportToBuf(lite::Model *model, void *buf, size_t *len) const { + // return model->ExportBuf(buf, len); return nullptr; } @@ -61,7 +71,7 @@ int TrainSession::RunGraph(const session::KernelCallBack &before, const session: if (train_mode_) return LiteSession::RunGraph(before, after); // object is expected to run only inference part of graph - // prepare a lit of kernels till the loss function -- temporary solution + // prepare a list of kernels till the loss function -- temporary solution std::vector infference_kernels; for (auto kernel : this->kernels_) { if (dynamic_cast(kernel) != nullptr) break; @@ -86,8 +96,16 @@ void TrainSession::train() { MS_ASSERT(nullptr != kernel); kernel->train(); } - train_mode_ = true; + for (auto it1 = ext_output_map_.begin(); it1 != ext_output_map_.end(); ++it1) { + if ((output_node_map_.find(it1->first) == output_node_map_.end()) || train_mode_) { + // Delete if not from output_node_map_ + auto tensor_ptr = it1->second.back(); + delete tensor_ptr; + it1->second.pop_back(); + } + } ext_output_map_.clear(); + train_mode_ = true; for (auto kernel : this->kernels_) { if (dynamic_cast(kernel) != nullptr) { auto *ms_tensor = new lite::Tensor(*kernel->out_tensors().at(0)); @@ -101,14 +119,23 @@ void TrainSession::eval() { MS_ASSERT(nullptr != kernel); kernel->eval(); } - train_mode_ = false; kernel::LiteKernel *last_kernel = nullptr; - // We should get in_kernels and then get all last kernels + for (auto it1 = ext_output_map_.begin(); it1 != ext_output_map_.end(); ++it1) { + if ((output_node_map_.find(it1->first) == output_node_map_.end()) || train_mode_) { + // Delete if not from output_node_map_ + auto tensor_ptr = it1->second.back(); + delete tensor_ptr; + it1->second.pop_back(); + } + } ext_output_map_ = output_node_map_; + train_mode_ = false; for (auto kernel : this->kernels_) { if ((dynamic_cast(kernel) != nullptr) && (last_kernel != nullptr)) { - auto *ms_tensor = new lite::Tensor(*last_kernel->out_tensors().at(0)); - ext_output_map_[last_kernel->name()].emplace_back(ms_tensor); + if (ext_output_map_.find(last_kernel->name()) == ext_output_map_.end()) { + auto *ms_tensor = new lite::Tensor(*last_kernel->out_tensors().at(0)); + ext_output_map_[last_kernel->name()].emplace_back(ms_tensor); + } } last_kernel = kernel; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc index a44f4e710f..db58fe1549 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc @@ -110,4 +110,18 @@ TEST_F(TestBNGradFp32, BNGradFp32) { delete kernel_obj; MS_LOG(INFO) << "BNGradFp32 passed"; } + +#if 0 +TEST_F(TestBNGradFp32, BNTtrainFp32) { + auto bn_param = static_cast(malloc(sizeof(BNGradParameter))); + bn_param->epsilon_ = 0.00001; + bn_param->momentum_ = 0.1; + const int batch = 2; + const int channels = 3; + const int height = 4; + const int width = 5; + auto x_tensor = CreateInTensor("./test_data/bngrad/input_x_2_4_5_3.bin", {batch, height, width, channels}); + std::vector inputs = {x_tensor, x_tensor, scale_tensor, mean_tensor, var_tensor}; +} +#endif } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc index 637fce5a8c..c5ef47ccb7 100644 --- 
a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc @@ -73,7 +73,6 @@ class NetworkTest : public mindspore::CommonTest { // +-------------+ | // V dw(9) | // +-----------Update-----+ -#if 0 TEST_F(NetworkTest, tuning_layer) { const int BATCH_SIZE = 32; const int NUM_CLASSES = 10; @@ -248,12 +247,15 @@ TEST_F(NetworkTest, tuning_layer) { label->nodeType = schema::NodeType::NodeType_ValueNode; label->format = schema::Format_NHWC; label->dataType = TypeId::kNumberTypeInt32; - label->dims = {BATCH_SIZE}; + label->dims = {BATCH_SIZE*NUM_CLASSES}; label->offset = -1; - label->data.resize(BATCH_SIZE * NUM_CLASSES * sizeof(float)); - int *data = reinterpret_cast(label->data.data()); - for (int i = 0; i < BATCH_SIZE; i++) - for (int j = 0; j < NUM_CLASSES; j++) *(data + i * NUM_CLASSES + j) = j; + // label->data.resize(BATCH_SIZE * NUM_CLASSES * sizeof(float)); + // int *data = reinterpret_cast(label->data.data()); + // for (int i = 0; i < BATCH_SIZE; i++) { + // for (int j = 0; j < NUM_CLASSES; j++) { + // *(data + i * NUM_CLASSES + j) = j; + // } + // } meta_graph->allTensors.emplace_back(std::move(label)); } // tensor 7 - Softmaxentropy @@ -378,6 +380,7 @@ TEST_F(NetworkTest, tuning_layer) { auto ret = session->CompileGraph(model); ASSERT_EQ(lite::RET_OK, ret); session->train(); + session->train(); // Just double check that calling train twice does not cause a problem auto inputs = session->GetInputs(); ASSERT_EQ(inputs.size(), 2); @@ -397,7 +400,7 @@ TEST_F(NetworkTest, tuning_layer) { delete [] buf; auto labelTensor = inputs.at(1); ASSERT_NE(nullptr, labelTensor); - ASSERT_EQ(BATCH_SIZE, labelTensor->ElementsNum()); + ASSERT_EQ(BATCH_SIZE*NUM_CLASSES, labelTensor->ElementsNum()); auto labels = reinterpret_cast(labelTensor->MutableData()); for (int i = 0; i < BATCH_SIZE; i++) labels[i] = (i * 97) % NUM_CLASSES; @@ -411,32 +414,67 @@ TEST_F(NetworkTest, tuning_layer) { auto *outData = reinterpret_cast(outTensor->MutableData()); ASSERT_NE(nullptr, outData); std::cout << "==============Initial=Scores===================" << std::endl; - for (int i = 0; i < 20; i++) { + for (int i = 0; i < 10; i++) { std::cout << outData[i] << ", "; } std::cout << std::endl; + session->eval(); + session->eval(); // Just double check that calling eval twice does not cause a problem ret = session->RunGraph(); outputs = session->GetOutputsByName("BiasAdd"); ASSERT_EQ(outputs.size(), 1); outTensor = (outputs.at(0)); ASSERT_NE(nullptr, outTensor); - // ASSERT_EQ(28 * 28 * 32, outTensor->ElementsNum()); ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); outData = reinterpret_cast(outTensor->MutableData()); ASSERT_NE(nullptr, outData); std::cout << "==============Scores=after-single=train========" << std::endl; - for (int i = 0; i < 20; i++) { + for (int i = 0; i < 10; i++) { std::cout << outData[i] << ", "; } std::string output_path = "./test_data/train/train_output_32_10.bin"; auto error = lite::RelativeOutputError(outData, output_path); EXPECT_LT(error, 2e-3); - MS_LOG(INFO) << "TuningLayer passed"; + + ret = session->RunGraph(); + outputs = session->GetOutputsByName("BiasAdd"); + ASSERT_EQ(outputs.size(), 1); + outTensor = (outputs.at(0)); + ASSERT_NE(nullptr, outTensor); + ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); + outData = reinterpret_cast(outTensor->MutableData()); + ASSERT_NE(nullptr, outData); + std::cout << "==============Scores=eval-second-time==========" << std::endl; + for (int 
i = 0; i < 10; i++) { + std::cout << outData[i] << ", "; + } + error = lite::RelativeOutputError(outData, output_path); + EXPECT_LT(error, 2e-3); + + session->train(); + session->eval(); // do some more zig-zags + ret = session->RunGraph(); + outputs = session->GetOutputsByName("BiasAdd"); + ASSERT_EQ(outputs.size(), 1); + outTensor = (outputs.at(0)); + ASSERT_NE(nullptr, outTensor); + ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); + outData = reinterpret_cast<float *>(outTensor->MutableData()); + ASSERT_NE(nullptr, outData); + std::cout << "==============Scores=Just Checking 3rd time====" << std::endl; + for (int i = 0; i < 10; i++) { + std::cout << outData[i] << ", "; + } + error = lite::RelativeOutputError(outData, output_path); + EXPECT_LT(error, 2e-3); + + delete model; delete session; + MS_LOG(INFO) << "TuningLayer passed"; } -#endif + int32_t fileIterator(mindspore::session::TrainSession *session, const std::string &path, std::function cb) { int32_t res = 0; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc index cb4d2a421f..6011e6b101 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc @@ -30,7 +30,7 @@ class TestSoftmaxCrossEntropyFp32 : public mindspore::CommonTest { TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { // prepare stage - SoftmaxCrossEntropyParameter *sce_param = new SoftmaxCrossEntropyParameter(); + auto sce_param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter))); size_t input_size; std::string input_path = "./test_data/operators/sce_fp32_1_y_6_4.bin"; @@ -83,9 +83,16 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { std::string grad_path = "./test_data/operators/sce_fp32_1_dy_6_4.bin"; lite::CompareOutput(grad, grad_path); - delete sce_param; - l_tensor.SetData(NULL); - y_tensor.SetData(NULL); + delete [] ll_labels; + delete [] labels; + delete [] input_data; + delete [] loss; + delete [] grad; + l_tensor.SetData(nullptr); + y_tensor.SetData(nullptr); + loss_tensor.SetData(nullptr); + grad_tensor.SetData(nullptr); + delete kernel_obj; MS_LOG(INFO) << "SoftmaxCrossEntropyFp32 passed"; }
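
The grad-kernel changes above repeatedly apply the same ownership pattern: initialize the scratch buffer to nullptr in the constructor, check the result of `new (std::nothrow)` in `Init()` and return an error code on failure, and release the buffer in the destructor. The sketch below is a minimal, self-contained illustration of that pattern; it is not MindSpore code, and the names (`ScratchKernel`, `workspace_`) are invented for the example.

```cpp
#include <cstddef>
#include <iostream>
#include <new>

// Minimal sketch of the allocation pattern used by the grad kernels in this
// patch: nullptr-initialized scratch buffer, checked nothrow allocation in
// Init(), and cleanup in the destructor.
class ScratchKernel {
 public:
  ScratchKernel() : workspace_(nullptr), size_(0) {}
  ~ScratchKernel() {
    // delete[] on a null pointer is a no-op, so this is safe even if Init()
    // was never called or failed.
    delete[] workspace_;
  }

  // Returns 0 on success, -1 on allocation failure (mirrors RET_OK/RET_ERROR).
  int Init(size_t elem_num) {
    size_ = elem_num;
    workspace_ = new (std::nothrow) float[elem_num];
    if (workspace_ == nullptr) {
      std::cerr << "failed to allocate workspace\n";
      return -1;
    }
    for (size_t i = 0; i < elem_num; ++i) workspace_[i] = 0.0f;
    return 0;
  }

 private:
  float *workspace_;
  size_t size_;
};

int main() {
  ScratchKernel k;
  return k.Init(1024) == 0 ? 0 : 1;
}
```

Since `delete[]` on a null pointer is well-defined in C++, the explicit `if (workspace)` guards seen in the destructors above are defensive rather than strictly required; the essential parts of the pattern are the nullptr initialization and the checked allocation.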