Browse Source

!13073 [MS_LITE] fix_bug_for_sec_to_master

From: @YeFeng_24
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
f9bd3eae41
15 changed files with 23 additions and 60 deletions
  1. +1
    -0
      mindspore/lite/examples/train_lenet/src/net_runner.cc
  2. +5
    -31
      mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc
  3. +2
    -1
      mindspore/lite/src/runtime/kernel/arm/base/crop_base.h
  4. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/crop_fp16.cc
  5. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/crop_fp32.cc
  6. +0
    -12
      mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc
  7. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h
  8. +1
    -6
      mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc
  9. +1
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
  10. +7
    -0
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
  11. +1
    -1
      mindspore/lite/src/tensor.cc
  12. +1
    -1
      mindspore/lite/tools/common/protobuf_utils.cc
  13. +1
    -1
      mindspore/lite/tools/converter/quantizer/quantize_util.cc
  14. +1
    -1
      mindspore/lite/tools/converter/quantizer/quantize_util.h
  15. +1
    -1
      mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc

+ 1
- 0
mindspore/lite/examples/train_lenet/src/net_runner.cc View File

@@ -45,6 +45,7 @@ class Rescaler : public mindspore::session::TrainLoopCallBack {
explicit Rescaler(float scale) : scale_(scale) {
if (scale_ == 0) scale_ = 1.0;
}
~Rescaler() override = default;
void StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) override {
auto inputs = cb_data.session_->GetInputs();
auto *input_data = reinterpret_cast<float *>(inputs.at(0)->MutableData());


+ 5
- 31
mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc View File

@@ -32,26 +32,11 @@ int CropBaseCPUKernel::Init() { return RET_OK; }
int CropBaseCPUKernel::ReSize() {
auto *input_tensor = in_tensors_.at(kInputIndex);
auto *out_tensor = out_tensors_.at(kOutputIndex);
auto input_shape = input_tensor->shape();
auto output_shape = out_tensor->shape();
size_t input_dim = input_shape.size();
size_t output_dim = output_shape.size();
FreeTmpBuffer();

crop_para_->in_shape_ = reinterpret_cast<int *>(malloc(input_dim * sizeof(int)));
if (crop_para_->in_shape_ == nullptr) {
MS_LOG(ERROR) << "in_shape_ is nullptr";
return RET_ERROR;
}
memcpy(crop_para_->in_shape_, input_shape.data(), sizeof(int) * input_dim);

crop_para_->out_shape_ = reinterpret_cast<int *>(malloc(output_dim * sizeof(int)));
if (crop_para_->out_shape_ == nullptr) {
MS_LOG(ERROR) << "out_shape_ is nullptr";
return RET_ERROR;
}
memcpy(crop_para_->out_shape_, output_shape.data(), sizeof(int) * output_dim);

input_shape_ = input_tensor->shape();
output_shape_ = out_tensor->shape();
size_t input_dim = input_shape_.size();
crop_para_->in_shape_ = input_shape_.data();
crop_para_->out_shape_ = output_shape_.data();
MS_ASSERT(input_dim <= CROP_OFFSET_MAX_SIZE);
crop_para_->input_dim_ = input_dim;
PadOffset(input_dim, crop_para_);
@@ -77,15 +62,4 @@ void CropBaseCPUKernel::PadOffset(int input_dim, CropParameter *crop_para) {
crop_para->in_offset_[i] = crop_offset;
}
}

void CropBaseCPUKernel::FreeTmpBuffer() {
if (crop_para_->in_shape_ != nullptr) {
free(crop_para_->in_shape_);
crop_para_->in_shape_ = nullptr;
}
if (crop_para_->out_shape_ != nullptr) {
free(crop_para_->out_shape_);
crop_para_->out_shape_ = nullptr;
}
}
} // namespace mindspore::kernel

+ 2
- 1
mindspore/lite/src/runtime/kernel/arm/base/crop_base.h View File

@@ -35,9 +35,10 @@ class CropBaseCPUKernel : public LiteKernel {
int Init() override;
int ReSize() override;
int Run() override { return 0; }
void FreeTmpBuffer();

protected:
std::vector<int> input_shape_;
std::vector<int> output_shape_;
CropParameter *crop_para_;
void PadOffset(int input_dim, CropParameter *crop_para);
};


+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/crop_fp16.cc View File

@@ -57,7 +57,6 @@ int CropFp16CPUKernel::Run() {
if (ret != RET_OK) {
MS_LOG(ERROR) << "ParallelLaunch failed: " << ret;
}
FreeTmpBuffer();
return ret;
}



+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/crop_fp32.cc View File

@@ -67,7 +67,6 @@ int CropCPUKernel::Run() {
MS_LOG(ERROR) << "Crop launch fail!ret: " << ret;
return RET_ERROR;
}
FreeTmpBuffer();
return RET_OK;
}



+ 0
- 12
mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc View File

@@ -49,18 +49,6 @@ int CropInt8CPUKernel::Init() {
return ReSize();
}

CropInt8CPUKernel::~CropInt8CPUKernel() {
if (crop_para_->in_shape_ != nullptr) {
free(const_cast<int *>(crop_para_->in_shape_));
crop_para_->in_shape_ = nullptr;
}

if (crop_para_->out_shape_ != nullptr) {
free(const_cast<int *>(crop_para_->out_shape_));
crop_para_->out_shape_ = nullptr;
}
}

int CropInt8CPUKernel::ReSize() { return CropBaseCPUKernel::ReSize(); }

int CropInt8CPUKernel::Run() {


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h View File

@@ -32,7 +32,7 @@ class CropInt8CPUKernel : public CropBaseCPUKernel {
CropInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx)
: CropBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~CropInt8CPUKernel();
~CropInt8CPUKernel() = default;

int Init() override;
int ReSize() override;


+ 1
- 6
mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc View File

@@ -506,12 +506,10 @@ kernel::LiteKernel *OpenCLConv2DCreator(const std::vector<lite::Tensor *> &input

// case 3: common conv2d
kernel::OpenCLKernel *kernel;
OpParameter *real_param;
bool infer_shape_done = opParameter->infer_flag_;
if (infer_shape_done && UseFcReplaceConv(inputs, outputs, conv_param)) {
auto *fc_param = CreateFcParam(conv_param, inputs);
kernel = new (std::nothrow) FullConnectionOpenCLKernel(fc_param, inputs, outputs, ctx);
real_param = fc_param;
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create FullConnection kernel failed.";
free(fc_param);
@@ -529,7 +527,6 @@ kernel::LiteKernel *OpenCLConv2DCreator(const std::vector<lite::Tensor *> &input
} else {
kernel = new (std::nothrow) Conv2DOpenCLKernel(reinterpret_cast<OpParameter *>(conv_param), inputs, outputs, ctx);
}
real_param = reinterpret_cast<OpParameter *>(conv_param);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create Convolution kernel failed.";
free(conv_param);
@@ -540,11 +537,9 @@ kernel::LiteKernel *OpenCLConv2DCreator(const std::vector<lite::Tensor *> &input
MS_LOG(WARNING) << "kernel don't infer shape yet!";
return kernel;
}
int ret = kernel->CheckSpecs();
if (ret != mindspore::lite::RET_OK) {
if (kernel->CheckSpecs() != RET_OK || kernel->OpenCLKernel::CheckSpecs() != RET_OK) {
MS_LOG(ERROR) << "Init Convolution kernel failed.";
delete kernel;
free(real_param);
return nullptr;
}
return kernel;


+ 1
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc View File

@@ -243,8 +243,7 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::Tensor *>
MS_LOG(WARNING) << "kernel don't infer shape yet!";
return kernel;
}
auto ret = kernel->CheckSpecs();
if (ret != RET_OK) {
if (kernel->CheckSpecs() != RET_OK || kernel->OpenCLKernel::CheckSpecs() != RET_OK) {
MS_LOG(ERROR) << "Check " << opParameter->name_ << " specification failed!";
delete kernel;
return nullptr;


+ 7
- 0
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc View File

@@ -406,6 +406,13 @@ int OpenCLKernel::CheckSpecs() {
return RET_ERROR;
}
}
if (in_tensors_.size() > 0) {
if (in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16 &&
in_tensors_[0]->data_type() != kNumberTypeInt32) {
MS_LOG(WARNING) << "Unsupported data type: " << in_tensors_[0]->data_type();
return RET_ERROR;
}
}
return RET_OK;
}
} // namespace mindspore::kernel

+ 1
- 1
mindspore/lite/src/tensor.cc View File

@@ -30,7 +30,7 @@ Tensor::Tensor(const TypeId data_type, std::vector<int> shape, const schema::For
: data_type_(data_type), shape_(std::move(shape)), format_(format), category_(category) {}

Tensor::Tensor(const std::string &name, enum TypeId type, const std::vector<int32_t> &shape, const void *data)
: tensor_name_(name), data_type_(type), shape_(std::move(shape)) {
: tensor_name_(name), data_type_(type), shape_(std::move(shape)), category_(VAR) {
data_ = const_cast<void *>(data);
}



+ 1
- 1
mindspore/lite/tools/common/protobuf_utils.cc View File

@@ -91,7 +91,7 @@ STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *mess
fs.close();

if (!success) {
MS_LOG(ERROR) << "Parse " << file << " failed.";
MS_LOG(DEBUG) << "Parse " << file << " failed.";
return RET_ERROR;
}



+ 1
- 1
mindspore/lite/tools/converter/quantizer/quantize_util.cc View File

@@ -940,7 +940,7 @@ STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int
return RET_OK;
}

void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas,
void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas,
bool channel_at_first, float *desired_max, float *desired_min) {
float min = FLT_MAX;
float max = -FLT_MAX;


+ 1
- 1
mindspore/lite/tools/converter/quantizer/quantize_util.h View File

@@ -110,7 +110,7 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc

STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size);

void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas,
void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas,
bool channel_at_first, float *desired_max, float *desired_min);

template <typename T>


+ 1
- 1
mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc View File

@@ -226,7 +226,7 @@ void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const
MS_LOG(ERROR) << "memcpy_s error:" << ret;
return;
}
new_weight_tensor->set_tensor_addr(temp_weight_data);
new_weight_tensor->SetTensorData(temp_weight_data, new_weight_tensor->tensor_size());
CalNewWeightTensor(conv_node, new_weight_tensor, kernel_num, trans_scale);
float *bias_data = nullptr;
// conv has bias,bias_flag true


Loading…
Cancel
Save