|
|
|
@@ -656,6 +656,11 @@ void TrainSession::CompileEvalOutputs() { |
|
|
|
if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) { |
|
|
|
for (auto in_kernel : kernel->in_kernels()) { |
|
|
|
if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue; |
|
|
|
bool is_loss = false; |
|
|
|
for (auto in_in_kernel : in_kernel->in_kernels()) { |
|
|
|
if (IsLossKernel(in_in_kernel)) is_loss = true; |
|
|
|
} |
|
|
|
if (is_loss) continue; |
|
|
|
// insert if not already in |
|
|
|
if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) { |
|
|
|
auto *ms_tensor = in_kernel->out_tensors().at(0); |
|
|
|
@@ -726,13 +731,53 @@ void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, st |
|
|
|
|
|
|
|
void TrainSession::CompileTrainKernels() { |
|
|
|
train_kernels_.clear(); |
|
|
|
std::vector<kernel::LiteKernel *> train_kernels; |
|
|
|
for (auto ori_kernel : kernels_) { |
|
|
|
if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { |
|
|
|
train_kernels_.push_back(ori_kernel); |
|
|
|
train_kernels.push_back(ori_kernel); |
|
|
|
} else { |
|
|
|
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(ori_kernel); |
|
|
|
for (auto kernel : sub_graph->nodes()) { |
|
|
|
train_kernels_.push_back(kernel); |
|
|
|
std::copy(sub_graph->nodes().begin(), sub_graph->nodes().end(), std::back_inserter(train_kernels)); |
|
|
|
} |
|
|
|
} |
|
|
|
// For LSTM GPU operators are synchronized internally hence we need to add sync mechanizm to execution graph |
|
|
|
for (auto k : train_kernels) { |
|
|
|
if (k->type() == schema::PrimitiveType_LSTMGradWeight) { |
|
|
|
// Find PrimitiveType_LSTMGradData that matches this PrimitiveType_LSTMGradWeight |
|
|
|
for (auto mk : train_kernels) { |
|
|
|
if (mk->type() == schema::PrimitiveType_LSTMGradData) { |
|
|
|
if (k->in_tensors().at(C2NUM)->tensor_name() == mk->in_tensors().at(0)->tensor_name()) { |
|
|
|
mk->AddOutKernel(k); |
|
|
|
k->AddInKernel(mk); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
std::queue<kernel::LiteKernel *> queue; |
|
|
|
for (auto k : train_kernels) { |
|
|
|
auto in_kernels = k->in_kernels(); |
|
|
|
if (in_kernels.size() == 0) { |
|
|
|
queue.push(k); |
|
|
|
} |
|
|
|
} |
|
|
|
std::unordered_map<kernel::LiteKernel *, int> map; |
|
|
|
while (queue.size()) { |
|
|
|
// pop first element |
|
|
|
auto k = queue.front(); |
|
|
|
train_kernels_.push_back(k); |
|
|
|
queue.pop(); |
|
|
|
for (auto &ok : k->out_kernels()) { |
|
|
|
auto cnt_iter = map.find(ok); |
|
|
|
if (cnt_iter == map.end()) { |
|
|
|
int ref_cnt = ok->in_kernels().size(); |
|
|
|
map[ok] = ref_cnt; |
|
|
|
} |
|
|
|
cnt_iter = map.find(ok); |
|
|
|
auto ref_cnt = cnt_iter->second - 1; |
|
|
|
map[ok] = ref_cnt; |
|
|
|
if (ref_cnt <= 0) { |
|
|
|
queue.push(cnt_iter->first); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|