
!8895 [MSLITE]improve train code

From: @guohonhzilonghw
Reviewed-by: @ddwsky,@zhanghaibo5
Signed-off-by: @zhanghaibo5
tags/v1.1.0
mindspore-ci-bot (Gitee) committed 5 years ago
commit 406a3dbdb3
3 changed files with 33 additions and 23 deletions
  1. mindspore/lite/src/train/loss_kernel.h  (+3, -3)
  2. mindspore/lite/src/train/train_model.cc  (+6, -5)
  3. mindspore/lite/src/train/train_session.cc  (+24, -15)

mindspore/lite/src/train/loss_kernel.h  (+3, -3)

@@ -22,9 +22,9 @@ namespace mindspore::kernel {
 class LossKernel : public LiteKernel {
  public:
   LossKernel() = default;
-  explicit LossKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                      const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
-                      const lite::PrimitiveC *primitive)
+  LossKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+             const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+             const lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~LossKernel() = default;
 };

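For context on the removed explicit: on a constructor that takes several parameters, explicit only forbids copy-list-initialization from a braced argument list, so dropping it lets a LossKernel be built directly from a braced list of its five arguments. A minimal standalone sketch of the difference, using hypothetical Plain/Marked types rather than the real kernel classes:

#include <vector>

struct Plain {
  // No 'explicit': copy-list-initialization is allowed.
  Plain(int id, const std::vector<int> &data) : id_(id), data_(data) {}
  int id_;
  std::vector<int> data_;
};

struct Marked {
  // 'explicit': only direct initialization compiles.
  explicit Marked(int id, const std::vector<int> &data) : id_(id), data_(data) {}
  int id_;
  std::vector<int> data_;
};

int main() {
  Plain a = {1, {2, 3}};      // OK with the non-explicit constructor
  Marked b{1, {2, 3}};        // OK: direct-list-initialization
  // Marked c = {1, {2, 3}};  // would not compile: the constructor is explicit
  return a.id_ + b.id_;
}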

mindspore/lite/src/train/train_model.cc  (+6, -5)

@@ -49,12 +49,14 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
   const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version);
   if (meta_graph == nullptr) {
     MS_LOG(ERROR) << "meta_graph is nullptr!";
+    free(model->buf);
     delete (model);
     return nullptr;
   }

   int status = GenerateModelByVersion(meta_graph, model, schema_version);
   if (status != RET_OK) {
+    free(model->buf);
     delete (model);
     MS_LOG(ERROR) << "fail to generate model";
     return nullptr;
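
For context on the two added free(model->buf) calls: Import copies the caller's buffer into model->buf, so the early-return error paths have to release that copy before deleting the model, otherwise it leaks. A minimal sketch of the same ownership pattern (hypothetical ImportCopy helper, not the real TrainModel::Import):

#include <cstddef>
#include <cstdlib>
#include <cstring>

struct Model {
  char *buf = nullptr;
  size_t size = 0;
};

// Hypothetical importer: copies the caller's buffer, then validates it.
// Every failure path frees the copy before deleting the owning object,
// mirroring the fix above.
Model *ImportCopy(const char *model_buf, size_t size, bool validate_ok) {
  Model *model = new Model;
  model->buf = reinterpret_cast<char *>(malloc(size));
  if (model->buf == nullptr) {
    delete model;
    return nullptr;
  }
  memcpy(model->buf, model_buf, size);
  model->size = size;

  if (!validate_ok) {   // stand-in for the meta_graph / status checks
    free(model->buf);   // release the copied buffer first ...
    delete model;       // ... then the object that owns it
    return nullptr;
  }
  return model;         // caller now owns model and model->buf
}
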
@@ -73,17 +75,16 @@ char *TrainModel::ExportBuf(char *buffer, size_t *len) const {
     MS_LOG(ERROR) << "Model::Export is only available for Train Session";
     return nullptr;
   }
-
   if (*len < buf_size_ && buffer != nullptr) {
     MS_LOG(ERROR) << "Buffer is too small, Export Failed";
     return nullptr;
   }
   if (buffer == nullptr) {
     buffer = reinterpret_cast<char *>(malloc(buf_size_));
-  }
-  if (buffer == nullptr) {
-    MS_LOG(ERROR) << "allocated model buf fail!";
-    return nullptr;
+    if (buffer == nullptr) {
+      MS_LOG(ERROR) << "allocated model buf fail!";
+      return nullptr;
+    }
   }

   memcpy(buffer, buf, buf_size_);

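The restructuring in ExportBuf nests the allocation-failure check inside the branch that actually calls malloc, so a caller-supplied buffer can never trip the "allocated model buf fail" message. A minimal sketch of this use-caller-buffer-or-allocate pattern (hypothetical ExportInto helper, not the real API):

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// The caller may pass its own buffer; only when none is given do we allocate,
// and only on that path can the allocation-failure branch be taken.
char *ExportInto(char *buffer, size_t *len, const char *src, size_t src_size) {
  if (buffer != nullptr && *len < src_size) {
    fprintf(stderr, "buffer too small\n");
    return nullptr;
  }
  if (buffer == nullptr) {
    buffer = reinterpret_cast<char *>(malloc(src_size));
    if (buffer == nullptr) {  // checked only when we did the allocation
      fprintf(stderr, "allocation failed\n");
      return nullptr;
    }
  }
  memcpy(buffer, src, src_size);
  *len = src_size;
  return buffer;
}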

mindspore/lite/src/train/train_session.cc  (+24, -15)

@@ -92,13 +92,22 @@ void TrainSession::AllocWorkSpace() {
 int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }

 int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) {
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "model is null";
+    return RET_ERROR;
+  }
   model_ = model;
-
   auto restore = ReplaceOps();
   auto ret = lite::LiteSession::CompileGraph(model);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Compile train graph failed";
+    return RET_ERROR;
+  }
   orig_output_map_ = output_node_map_;
   orig_output_tensor_map_ = output_tensor_map_;
-  for (auto inTensor : inputs_) inTensor->MutableData();
+  for (auto inTensor : inputs_) {
+    inTensor->MutableData();
+  }
   RestoreOps(restore);
   AllocWorkSpace();
   MarkOptimizedKernels();
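
CompileTrainGraph now validates its argument and the result of the base-class compile before touching any session state, returning early with an error instead of continuing with a half-built graph. A minimal sketch of that guard-clause shape (hypothetical names; BaseCompile stands in for lite::LiteSession::CompileGraph):

enum Status { RET_OK = 0, RET_ERROR = 1 };

struct Graph {
  bool valid = false;
};

// Hypothetical base compile step.
Status BaseCompile(const Graph *g) { return g->valid ? RET_OK : RET_ERROR; }

Status CompileTrainGraphSketch(Graph *g) {
  if (g == nullptr) {           // guard clause: reject null input up front
    return RET_ERROR;
  }
  Status ret = BaseCompile(g);  // delegate, then check the result before
  if (ret != RET_OK) {          // doing any train-specific setup
    return RET_ERROR;
  }
  // ... train-specific setup (output maps, workspace, etc.) would follow ...
  return RET_OK;
}
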
@@ -152,7 +161,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a
 int TrainSession::SaveToFile(const std::string &filename) const {
   size_t fb_size = 0;
   auto *buf = reinterpret_cast<char *>(ExportToBuf(nullptr, &fb_size));
-  if (buf == NULL) {
+  if (buf == nullptr) {
     MS_LOG(ERROR) << "Could not Export Trained model";
     return lite::RET_NULL_PTR;
   }
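
Replacing NULL with nullptr is the usual C++11 cleanup: nullptr has its own type (std::nullptr_t), so it never competes with integer overloads the way the macro NULL can. A tiny illustration with hypothetical overloads:

#include <cstddef>

void take(int) {}     // (1) integer overload
void take(char *) {}  // (2) pointer overload

int main() {
  take(nullptr);  // always calls (2): nullptr is unambiguously a null pointer
  take(0);        // calls (1)
  // take(NULL);  // calls (1) or is ambiguous, depending on how NULL is defined
  return 0;
}
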
@@ -212,7 +221,7 @@ int TrainSession::Train() {
 }

 void TrainSession::UpdateOutputMapByLossKernel(const kernel::LiteKernel *kernel) {
-  if (IsLossKernel(kernel)) {
+  if (kernel != nullptr && IsLossKernel(kernel)) {
     auto *ms_tensor = kernel->out_tensors().at(0);
     if (ms_tensor != nullptr) {
       (void)ms_tensor->MutableData();
@@ -226,7 +235,7 @@ void TrainSession::UpdateOutputMapByLossKernel(const kernel::LiteKernel *kernel)
 }

 void TrainSession::UpdateOutputMapByInKernel(const kernel::LiteKernel *kernel) {
-  if (IsLossKernel(kernel)) {
+  if (kernel != nullptr && IsLossKernel(kernel)) {
     for (auto in_kernel : kernel->in_kernels()) {
       if (output_node_map_.find(in_kernel->name()) == output_node_map_.end()) {
         auto *ms_tensor = in_kernel->out_tensors().at(0);
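
Both UpdateOutputMapBy* helpers now check kernel != nullptr before IsLossKernel(kernel); since && short-circuits, IsLossKernel is never called with a null pointer. A minimal sketch with a hypothetical Kernel type and IsLoss predicate:

#include <string>

struct Kernel {
  std::string type;
};

// Hypothetical stand-in for TrainSession::IsLossKernel().
bool IsLoss(const Kernel *k) { return k->type == "SoftmaxCrossEntropyWithLogits"; }

void UpdateOutputs(const Kernel *kernel) {
  // The left operand guards the right one: when kernel is null, IsLoss()
  // is never evaluated, so it cannot dereference a null pointer.
  if (kernel != nullptr && IsLoss(kernel)) {
    // ... update the output maps here ...
  }
}

int main() {
  Kernel loss{"SoftmaxCrossEntropyWithLogits"};
  UpdateOutputs(&loss);    // guarded branch taken
  UpdateOutputs(nullptr);  // safe: short-circuit skips IsLoss()
  return 0;
}
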
@@ -304,9 +313,9 @@ void TrainSession::BuildInferenceKernelsMap() {
       }
     } else {
       auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
-      for (auto sb_kernel : sub_graph->nodes()) {
-        if (IsLossKernel(sb_kernel)) {  // For each loss in the system add backward tree
-          for (auto in_node : sb_kernel->in_kernels()) {
+      for (auto sub_kernel : sub_graph->nodes()) {
+        if (IsLossKernel(sub_kernel)) {  // For each loss in the system add backward tree
+          for (auto in_node : sub_kernel->in_kernels()) {
             BuildInferenceKernelsRecursive(in_node, &req_kernels);
           }
         }
@@ -357,9 +366,9 @@ void TrainSession::MarkOptimizedKernels() {
       }
     } else {
       auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
-      for (auto sb_kernel : sub_graph->nodes()) {
-        if (IsOptimizer(sb_kernel)) {
-          std::copy(sb_kernel->in_tensors().begin(), sb_kernel->in_tensors().end(), std::back_inserter(ot));
+      for (auto sub_kernel : sub_graph->nodes()) {
+        if (IsOptimizer(sub_kernel)) {
+          std::copy(sub_kernel->in_tensors().begin(), sub_kernel->in_tensors().end(), std::back_inserter(ot));
         }
       }
     }
@@ -376,11 +385,11 @@ void TrainSession::MarkOptimizedKernels() {
       }
     } else {
       auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
-      for (auto sb_kernel : sub_graph->nodes()) {
-        if (!IsOptimizer(sb_kernel)) {
-          for (auto it : sb_kernel->in_tensors()) {
+      for (auto sub_kernel : sub_graph->nodes()) {
+        if (!IsOptimizer(sub_kernel)) {
+          for (auto it : sub_kernel->in_tensors()) {
             if (std::find(ot.begin(), ot.end(), it) != ot.end()) {
-              sb_kernel->set_trainable(true);
+              sub_kernel->set_trainable(true);
               break;
             }
           }

