|
|
|
@@ -92,13 +92,22 @@ void TrainSession::AllocWorkSpace() { |
|
|
|
int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; } |
|
|
|
|
|
|
|
int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) { |
|
|
|
if (model == nullptr) { |
|
|
|
MS_LOG(ERROR) << "model is null"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
model_ = model; |
|
|
|
|
|
|
|
auto restore = ReplaceOps(); |
|
|
|
auto ret = lite::LiteSession::CompileGraph(model); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Compile train graph failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
orig_output_map_ = output_node_map_; |
|
|
|
orig_output_tensor_map_ = output_tensor_map_; |
|
|
|
for (auto inTensor : inputs_) inTensor->MutableData(); |
|
|
|
for (auto inTensor : inputs_) { |
|
|
|
inTensor->MutableData(); |
|
|
|
} |
|
|
|
RestoreOps(restore); |
|
|
|
AllocWorkSpace(); |
|
|
|
MarkOptimizedKernels(); |
|
|
|
@@ -152,7 +161,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a |
|
|
|
int TrainSession::SaveToFile(const std::string &filename) const { |
|
|
|
size_t fb_size = 0; |
|
|
|
auto *buf = reinterpret_cast<char *>(ExportToBuf(nullptr, &fb_size)); |
|
|
|
if (buf == NULL) { |
|
|
|
if (buf == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Could not Export Trained model"; |
|
|
|
return lite::RET_NULL_PTR; |
|
|
|
} |
|
|
|
@@ -212,7 +221,7 @@ int TrainSession::Train() { |
|
|
|
} |
|
|
|
|
|
|
|
void TrainSession::UpdateOutputMapByLossKernel(const kernel::LiteKernel *kernel) { |
|
|
|
if (IsLossKernel(kernel)) { |
|
|
|
if (kernel != nullptr && IsLossKernel(kernel)) { |
|
|
|
auto *ms_tensor = kernel->out_tensors().at(0); |
|
|
|
if (ms_tensor != nullptr) { |
|
|
|
(void)ms_tensor->MutableData(); |
|
|
|
@@ -226,7 +235,7 @@ void TrainSession::UpdateOutputMapByLossKernel(const kernel::LiteKernel *kernel) |
|
|
|
} |
|
|
|
|
|
|
|
void TrainSession::UpdateOutputMapByInKernel(const kernel::LiteKernel *kernel) { |
|
|
|
if (IsLossKernel(kernel)) { |
|
|
|
if (kernel != nullptr && IsLossKernel(kernel)) { |
|
|
|
for (auto in_kernel : kernel->in_kernels()) { |
|
|
|
if (output_node_map_.find(in_kernel->name()) == output_node_map_.end()) { |
|
|
|
auto *ms_tensor = in_kernel->out_tensors().at(0); |
|
|
|
@@ -304,9 +313,9 @@ void TrainSession::BuildInferenceKernelsMap() { |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel); |
|
|
|
for (auto sb_kernel : sub_graph->nodes()) { |
|
|
|
if (IsLossKernel(sb_kernel)) { // For each loss in the system add backward tree |
|
|
|
for (auto in_node : sb_kernel->in_kernels()) { |
|
|
|
for (auto sub_kernel : sub_graph->nodes()) { |
|
|
|
if (IsLossKernel(sub_kernel)) { // For each loss in the system add backward tree |
|
|
|
for (auto in_node : sub_kernel->in_kernels()) { |
|
|
|
BuildInferenceKernelsRecursive(in_node, &req_kernels); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -357,9 +366,9 @@ void TrainSession::MarkOptimizedKernels() { |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel); |
|
|
|
for (auto sb_kernel : sub_graph->nodes()) { |
|
|
|
if (IsOptimizer(sb_kernel)) { |
|
|
|
std::copy(sb_kernel->in_tensors().begin(), sb_kernel->in_tensors().end(), std::back_inserter(ot)); |
|
|
|
for (auto sub_kernel : sub_graph->nodes()) { |
|
|
|
if (IsOptimizer(sub_kernel)) { |
|
|
|
std::copy(sub_kernel->in_tensors().begin(), sub_kernel->in_tensors().end(), std::back_inserter(ot)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -376,11 +385,11 @@ void TrainSession::MarkOptimizedKernels() { |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel); |
|
|
|
for (auto sb_kernel : sub_graph->nodes()) { |
|
|
|
if (!IsOptimizer(sb_kernel)) { |
|
|
|
for (auto it : sb_kernel->in_tensors()) { |
|
|
|
for (auto sub_kernel : sub_graph->nodes()) { |
|
|
|
if (!IsOptimizer(sub_kernel)) { |
|
|
|
for (auto it : sub_kernel->in_tensors()) { |
|
|
|
if (std::find(ot.begin(), ot.end(), it) != ot.end()) { |
|
|
|
sb_kernel->set_trainable(true); |
|
|
|
sub_kernel->set_trainable(true); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|