
fix review comments

tags/v1.1.0
zhanghaibo5 · 5 years ago · commit 5296c25190
7 changed files with 38 additions and 24 deletions
  1. mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc (+7 -6)
  2. mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc (+6 -1)
  3. mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.cc (+5 -2)
  4. mindspore/lite/src/sub_graph_kernel.cc (+1 -1)
  5. mindspore/lite/src/train/train_model.cc (+3 -2)
  6. mindspore/lite/src/train/train_populate_parameter.cc (+11 -11)
  7. mindspore/lite/src/train/train_session.cc (+5 -1)

mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc (+7 -6)

@@ -38,11 +38,11 @@ int CropBaseCPUKernel::ReSize() {
   if (crop_para_->in_shape_ == nullptr) {
     MS_LOG(ERROR) << "in_shape_ is nullptr";
     return RET_ERROR;
+  } else {
+    memcpy(reinterpret_cast<void *>(const_cast<int *>(crop_para_->in_shape_)), input_shape.data(),
+           sizeof(int) * input_dim);
   }
 
-  memcpy(reinterpret_cast<void *>(const_cast<int *>(crop_para_->in_shape_)), input_shape.data(),
-         sizeof(int) * input_dim);
-
   auto *out_tensor = out_tensors_.at(kOutputIndex);
   auto output_shape = out_tensor->shape();
   size_t output_dim = output_shape.size();
@@ -51,10 +51,11 @@ int CropBaseCPUKernel::ReSize() {
   if (crop_para_->out_shape_ == nullptr) {
     MS_LOG(ERROR) << "out_shape_ is nullptr";
     return RET_ERROR;
+  } else {
+    memcpy(reinterpret_cast<void *>(const_cast<int *>(crop_para_->out_shape_)), output_shape.data(),
+           sizeof(int) * output_dim);
   }
 
-  memcpy(reinterpret_cast<void *>(const_cast<int *>(crop_para_->out_shape_)), output_shape.data(),
-         sizeof(int) * output_dim);
 
   MS_ASSERT(input_dim <= CROP_OFFSET_MAX_SIZE);
   crop_para_->input_dim_ = input_dim;
   PadOffset(input_dim, crop_para_);
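
The reshaped check keeps the memcpy behind the null test of its destination, with success and failure branches made explicit. A condensed sketch of the pattern (hypothetical helper name, not the kernel API; the const_cast mirrors the kernel code, where the shape field is declared const int * but points at writable storage at this point in the lifecycle):

#include <cstring>
#include <vector>

bool CopyShapeChecked(const int *dst, const std::vector<int> &shape) {
  if (dst == nullptr) {
    return false;  // stands in for the MS_LOG(ERROR) + RET_ERROR branch
  } else {
    std::memcpy(const_cast<int *>(dst), shape.data(), sizeof(int) * shape.size());
  }
  return true;
}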


mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc (+6 -1)

@@ -51,9 +51,14 @@ int BiasAddInt8CPUKernel::Run() {
   auto out = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
   size_t data_size = in_tensors_.at(0)->ElementsNum();
   auto tile_in = static_cast<int8_t *>(ctx_->allocator->Malloc(data_size));
+  if (tile_in == nullptr) {
+    MS_LOG(ERROR) << "Failed to malloc momery";
+    return NNACL_ERR;
+  }
   auto tile_bias = static_cast<int8_t *>(ctx_->allocator->Malloc(data_size));
-  if (tile_in == nullptr || tile_bias == nullptr) {
+  if (tile_bias == nullptr) {
     MS_LOG(ERROR) << "Failed to malloc momery";
+    ctx_->allocator->Free(tile_in);
     return NNACL_ERR;
   }
   BroadcastAddInt8(in, bias, tile_in, tile_bias, out, data_size,
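
The fix converts one combined after-the-fact check into per-allocation checks, so each error path can release exactly what has already been allocated. A standalone sketch of the pattern (plain malloc/free stand in for the ctx_->allocator API; names are illustrative):

#include <cstdint>
#include <cstdlib>

int RunWithTwoBuffers(std::size_t data_size) {
  auto *tile_in = static_cast<int8_t *>(std::malloc(data_size));
  if (tile_in == nullptr) {
    return -1;  // first allocation failed: nothing to clean up yet
  }
  auto *tile_bias = static_cast<int8_t *>(std::malloc(data_size));
  if (tile_bias == nullptr) {
    std::free(tile_in);  // second failed: the first buffer must not leak
    return -1;
  }
  // ... broadcast-add work would happen here ...
  std::free(tile_in);
  std::free(tile_bias);
  return 0;
}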


mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.cc (+5 -2)

@@ -132,11 +132,14 @@ int SubInt8CPUKernel::Run() {
     tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
   }
   tile0_data_ = static_cast<int8_t *>(context_->allocator->Malloc(out_tensors_.at(0)->Size()));
+  if (tile0_data_ == nullptr) {
+    MS_LOG(ERROR) << "malloc memroy fail!";
+    return RET_ERROR;
+  }
   tile1_data_ = static_cast<int8_t *>(context_->allocator->Malloc(out_tensors_.at(0)->Size()));
-  if (tile0_data_ == nullptr || tile1_data_ == nullptr) {
+  if (tile1_data_ == nullptr) {
     MS_LOG(ERROR) << "malloc memroy fail!";
     context_->allocator->Free(tile0_data_);
-    context_->allocator->Free(tile1_data_);
     return RET_ERROR;
   }
   TileDimensionsUint8(static_cast<uint8_t *>(in_tensors_.at(0)->MutableData()),
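
The commit keeps the manual Free calls. For comparison, a RAII sketch of the same idea (not what the commit does; illustrative names, plain malloc/free rather than the lite allocator) makes every early return leak-free without explicit cleanup:

#include <cstdint>
#include <cstdlib>
#include <memory>

using Buffer = std::unique_ptr<int8_t, decltype(&std::free)>;

int RunWithRaii(std::size_t size) {
  Buffer tile0(static_cast<int8_t *>(std::malloc(size)), &std::free);
  if (tile0 == nullptr) {
    return -1;
  }
  Buffer tile1(static_cast<int8_t *>(std::malloc(size)), &std::free);
  if (tile1 == nullptr) {
    return -1;  // tile0 is freed automatically as the function unwinds
  }
  // ... TileDimensionsUint8 / SubInt8 work would go here ...
  return 0;
}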


mindspore/lite/src/sub_graph_kernel.cc (+1 -1)

@@ -239,7 +239,7 @@ int CpuFp16SubGraph::PostProcess() {
   tensor->set_data(nullptr);
   tensor->set_data_type(TypeId::kNumberTypeFloat32);
   auto ret = tensor->MallocData();
-  if (RET_OK != ret) {
+  if (ret != RET_OK) {
     MS_LOG(ERROR) << "malloc data failed";
     if (this->context_ != nullptr && this->context_->allocator != nullptr) {
       this->context_->allocator->Free(float16_data);


mindspore/lite/src/train/train_model.cc (+3 -2)

@@ -40,14 +40,13 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
   model->buf = reinterpret_cast<char *>(malloc(size));
   if (model->buf == nullptr) {
     delete model;
-    MS_LOG(ERROR) << "new inner model buf fail!";
+    MS_LOG(ERROR) << "malloc inner model buf fail!";
     return nullptr;
   }
   memcpy(model->buf, model_buf, size);
   model->buf_size_ = size;
   auto meta_graph = schema::GetMetaGraph(model->buf);
   if (meta_graph == nullptr) {
-    free(model->buf);
     delete model;
     MS_LOG(ERROR) << "meta_graph is nullptr!";
     return nullptr;
@@ -73,6 +72,7 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
     int ret = MetaGraphMappingSubGraph(*meta_graph, model);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "converter old version model wrong.";
+      delete model;
       return nullptr;
     }
   } else {
@@ -83,6 +83,7 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
       int ret = ConvertSubGraph(*sub_graph, model);
       if (ret != RET_OK) {
         MS_LOG(ERROR) << "converter subgraph wrong.";
+        delete model;
         return nullptr;
       }
     }
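
All three hunks enforce one ownership rule: once buf is attached to the model, `delete model` is the single release point. Assuming, as the removed free(model->buf) suggests, that the destructor already frees buf, the old free-then-delete sequence risked a double free, and the error paths that returned without delete leaked the model. A sketch under that assumption (illustrative types, not the TrainModel API):

#include <cstdlib>
#include <cstring>
#include <new>

struct ModelSketch {
  char *buf = nullptr;
  ~ModelSketch() { std::free(buf); }  // assumed owner of buf; free(nullptr) is a no-op
};

ModelSketch *ImportSketch(const char *src, std::size_t size) {
  auto *model = new (std::nothrow) ModelSketch();
  if (model == nullptr) {
    return nullptr;
  }
  model->buf = static_cast<char *>(std::malloc(size));
  if (model->buf == nullptr) {
    delete model;  // destructor sees buf == nullptr, nothing double-freed
    return nullptr;
  }
  std::memcpy(model->buf, src, size);
  bool parse_ok = size > 0;  // placeholder for GetMetaGraph/ConvertSubGraph
  if (!parse_ok) {
    delete model;  // also frees buf; a separate free(model->buf) here would double-free
    return nullptr;
  }
  return model;
}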


mindspore/lite/src/train/train_populate_parameter.cc (+11 -11)

@@ -50,7 +50,7 @@ OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primiti
 
   OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
   if (param == nullptr) {
-    MS_LOG(ERROR) << "new Param for primitive failed.";
+    MS_LOG(ERROR) << "malloc Param for primitive failed.";
     return nullptr;
   }
 
@@ -65,7 +65,7 @@ OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *p
   }
   ApplyMomentumParameter *p = reinterpret_cast<ApplyMomentumParameter *>(malloc(sizeof(ApplyMomentumParameter)));
   if (p == nullptr) {
-    MS_LOG(ERROR) << "new ApplyMomentumParameter failed.";
+    MS_LOG(ERROR) << "malloc ApplyMomentumParameter failed.";
     return nullptr;
   }
   p->op_parameter_.type_ = primitive->Type();
@@ -128,7 +128,7 @@ OpParameter *PopulateSgdParameter(const mindspore::lite::PrimitiveC *primitive)
   }
   SgdParameter *p = reinterpret_cast<SgdParameter *>(malloc(sizeof(SgdParameter)));
   if (p == nullptr) {
-    MS_LOG(ERROR) << "new SgdParameter failed.";
+    MS_LOG(ERROR) << "malloc SgdParameter failed.";
     return nullptr;
   }
   p->op_parameter_.type_ = primitive->Type();
@@ -150,7 +150,7 @@ OpParameter *PopulateSoftmaxCrossEntropyParameter(const mindspore::lite::Primiti
   SoftmaxCrossEntropyParameter *sce_param =
     reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
   if (sce_param == nullptr) {
-    MS_LOG(ERROR) << "new SoftmaxCrossEntropyParameter failed.";
+    MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
     return nullptr;
   }
   sce_param->op_parameter_.type_ = primitive->Type();
@@ -164,7 +164,7 @@ OpParameter *PopulatePoolingGradParameter(const mindspore::lite::PrimitiveC *pri
   }
   PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
   if (pooling_param == nullptr) {
-    MS_LOG(ERROR) << "new PoolingParameter failed.";
+    MS_LOG(ERROR) << "malloc PoolingParameter failed.";
     return nullptr;
   }
   pooling_param->op_parameter_.type_ = primitive->Type();
@@ -217,7 +217,7 @@ OpParameter *PopulateActivationGradParameter(const mindspore::lite::PrimitiveC *
 
   ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
   if (act_param == nullptr) {
-    MS_LOG(ERROR) << "new ActivationParameter failed.";
+    MS_LOG(ERROR) << "malloc ActivationParameter failed.";
     return nullptr;
   }
   act_param->op_parameter_.type_ = primitive->Type();
@@ -236,7 +236,7 @@ OpParameter *PopulateConvolutionGradFilterParameter(const mindspore::lite::Primi
 
   ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
   if (param == nullptr) {
-    MS_LOG(ERROR) << "new Param for conv grad filter failed.";
+    MS_LOG(ERROR) << "malloc Param for conv grad filter failed.";
     return nullptr;
   }
   param->op_parameter_.type_ = primitive->Type();
@@ -277,7 +277,7 @@ OpParameter *PopulateConvolutionGradInputParameter(const mindspore::lite::Primit
 
   ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
   if (param == nullptr) {
-    MS_LOG(ERROR) << "new Param for conv grad filter failed.";
+    MS_LOG(ERROR) << "malloc Param for conv grad filter failed.";
     return nullptr;
   }
   param->op_parameter_.type_ = primitive->Type();
@@ -359,7 +359,7 @@ OpParameter *PopulatePowerGradParameter(const mindspore::lite::PrimitiveC *primi
 
   PowerParameter *power_param = reinterpret_cast<PowerParameter *>(malloc(sizeof(PowerParameter)));
   if (power_param == nullptr) {
-    MS_LOG(ERROR) << "new PowerParameter failed.";
+    MS_LOG(ERROR) << "malloc PowerParameter failed.";
     return nullptr;
   }
   power_param->op_parameter_.type_ = primitive->Type();
@@ -378,7 +378,7 @@ OpParameter *PopulateBiasGradParameter(const mindspore::lite::PrimitiveC *primit
 
   ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
   if (arithmetic_param == nullptr) {
-    MS_LOG(ERROR) << "new ArithmeticParameter failed.";
+    MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
     return nullptr;
   }
   arithmetic_param->op_parameter_.type_ = primitive->Type();
@@ -393,7 +393,7 @@ OpParameter *PopulateBNGradParameter(const mindspore::lite::PrimitiveC *primitiv
 
   BNGradParameter *bnGrad_param = reinterpret_cast<BNGradParameter *>(malloc(sizeof(BNGradParameter)));
   if (bnGrad_param == nullptr) {
-    MS_LOG(ERROR) << "new BNGradParameter failed.";
+    MS_LOG(ERROR) << "malloc BNGradParameter failed.";
     return nullptr;
   }
   bnGrad_param->op_parameter_.type_ = primitive->Type();
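
All eleven hunks only reword "new X failed." to "malloc X failed." so the log names the allocator actually used. A hypothetical helper (not part of the codebase; fprintf stands in for MS_LOG) that would keep the wording and the allocation in sync by construction:

#include <cstdio>
#include <cstdlib>

template <typename T>
T *MallocParam(const char *name) {
  T *p = static_cast<T *>(malloc(sizeof(T)));
  if (p == nullptr) {
    std::fprintf(stderr, "malloc %s failed.\n", name);  // same wording everywhere
  }
  return p;
}
// usage sketch: SgdParameter *p = MallocParam<SgdParameter>("SgdParameter");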


mindspore/lite/src/train/train_session.cc (+5 -1)

@@ -279,7 +279,11 @@ bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) {
 }  // namespace lite
 
 session::TrainSession *session::TrainSession::CreateSession(lite::Context *context) {
-  auto session = new lite::TrainSession();
+  auto session = new (std::nothrow) lite::TrainSession();
+  if (session == nullptr) {
+    MS_LOG(ERROR) << "create session failed";
+    return nullptr;
+  }
   auto ret = session->Init(context);
   if (ret != mindspore::lite::RET_OK) {
     MS_LOG(ERROR) << "init sesssion failed";

