|
|
|
@@ -21,9 +21,12 @@ |
|
|
|
#include <string> |
|
|
|
#include <algorithm> |
|
|
|
#include <numeric> |
|
|
|
#include <deque> |
|
|
|
#include <functional> |
|
|
|
#include "backend/session/session_basic.h" |
|
|
|
#include "backend/session/session_factory.h" |
|
|
|
#include "backend/optimizer/common/optimizer.h" |
|
|
|
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h" |
|
|
|
#include "cxx_api/factory.h" |
|
|
|
#include "vm/backend.h" |
|
|
|
#include "vm/transform.h" |
|
|
|
@@ -56,14 +59,7 @@ class MSTensorRef : public BaseRef { |
|
|
|
std::vector<MSTensor> res; |
|
|
|
if (utils::isa<VectorRef>(args)) { |
|
|
|
VectorRef args_vec = utils::cast<VectorRef>(args); |
|
|
|
for (size_t i = 0; i < args_vec.size(); ++i) { |
|
|
|
const auto &item = args_vec[i]; |
|
|
|
if (!utils::isa<MSTensorRef>(item)) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid item " << item.ToString() << " at index " << i; |
|
|
|
} |
|
|
|
auto wrapper = utils::cast<MSTensorRef>(item); |
|
|
|
res.push_back(wrapper.ms_tensor_); |
|
|
|
} |
|
|
|
res = ConvertTuple(args_vec); |
|
|
|
} else if (utils::isa<MSTensorRef>(args)) { |
|
|
|
auto wrapper = utils::cast<MSTensorRef>(args); |
|
|
|
res.push_back(wrapper.ms_tensor_); |
|
|
|
@@ -101,6 +97,25 @@ class MSTensorRef : public BaseRef { |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
static std::vector<MSTensor> ConvertTuple(const VectorRef &args) { |
|
|
|
std::vector<MSTensor> outs; |
|
|
|
for (size_t i = 0; i < args.size(); ++i) { |
|
|
|
const auto &item = args[i]; |
|
|
|
if (utils::isa<VectorRef>(item)) { |
|
|
|
VectorRef args_vec = utils::cast<VectorRef>(args); |
|
|
|
auto ret = ConvertTuple(args_vec); |
|
|
|
outs.insert(outs.end(), ret.begin(), ret.end()); |
|
|
|
} else if (utils::isa<MSTensorRef>(item)) { |
|
|
|
auto wrapper = utils::cast<MSTensorRef>(item); |
|
|
|
outs.push_back(wrapper.ms_tensor_); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString() |
|
|
|
<< " must be MSTensorRef or VectorRef{MSTensorRef...}"; |
|
|
|
} |
|
|
|
} |
|
|
|
return outs; |
|
|
|
} |
|
|
|
|
|
|
|
MSTensor ms_tensor_; |
|
|
|
}; |
|
|
|
|
|
|
|
@@ -114,7 +129,11 @@ class MultiGraphAclSession : public session::SessionBasic { |
|
|
|
void SetOptions(const std::shared_ptr<AclModelOptions> &options) { options_ = options; } |
|
|
|
|
|
|
|
private: |
|
|
|
VectorRef ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors); |
|
|
|
VectorRef ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, std::deque<MSTensor> *out_tensors); |
|
|
|
|
|
|
|
std::map<GraphId, GraphCell> graphs_ = {}; |
|
|
|
std::map<GraphId, KernelGraphPtr> kernel_graphs_ = {}; |
|
|
|
std::shared_ptr<AclModelOptions> options_ = nullptr; |
|
|
|
}; |
|
|
|
|
|
|
|
@@ -138,8 +157,16 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const |
|
|
|
std::shared_ptr<AclModelOptions> options_; |
|
|
|
}; |
|
|
|
MS_LOG(INFO) << "Start MultiGraph Compile."; |
|
|
|
auto kernel_graph = ConstructKernelGraph(lst, outputs, false); |
|
|
|
// construct kernel graph |
|
|
|
auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, false); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>(); |
|
|
|
auto pm = std::make_shared<opt::PassManager>("310_multi_graph_pm"); |
|
|
|
pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicRNN>()); |
|
|
|
optimizer->AddPassManager(pm); |
|
|
|
(void)optimizer->Optimize(kernel_graph); |
|
|
|
kernel_graph->SetExecOrderByDefault(); |
|
|
|
// concert to om data |
|
|
|
ModelConverter model_converter_; |
|
|
|
model_converter_.set_options(options_); |
|
|
|
FirstGraphModeGuard guard(options_); |
|
|
|
@@ -148,6 +175,7 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const |
|
|
|
MS_LOG(ERROR) << "Load MindIR failed."; |
|
|
|
return kMCFailed; |
|
|
|
} |
|
|
|
// load |
|
|
|
std::shared_ptr<Graph> graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM)); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto graph_cell = GraphCell(graph); |
|
|
|
@@ -156,6 +184,7 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const |
|
|
|
MS_LOG(EXCEPTION) << "Load failed."; |
|
|
|
} |
|
|
|
graphs_[kernel_graph->graph_id()] = graph_cell; |
|
|
|
kernel_graphs_[kernel_graph->graph_id()] = kernel_graph; |
|
|
|
MS_LOG(INFO) << "Mulit graph compile success, graph id " << kernel_graph->graph_id(); |
|
|
|
return kernel_graph->graph_id(); |
|
|
|
} |
|
|
|
@@ -172,7 +201,61 @@ void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector<MSTensor |
|
|
|
if (ret != kSuccess) { |
|
|
|
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed."; |
|
|
|
} |
|
|
|
(*outputs) = MSTensorRef::Convert(out_tensors); |
|
|
|
|
|
|
|
std::deque<MSTensor> out_tensors_deque(out_tensors.begin(), out_tensors.end()); |
|
|
|
(*outputs) = ConstructOutputRef(graph_id, &out_tensors_deque); |
|
|
|
} |
|
|
|
|
|
|
|
VectorRef MultiGraphAclSession::ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors) { |
|
|
|
MS_EXCEPTION_IF_NULL(out_tensors); |
|
|
|
VectorRef outs; |
|
|
|
auto out_nodes = kernel_graphs_[graph_id]->outputs(); |
|
|
|
for (auto &out : out_nodes) { |
|
|
|
if (out_tensors->empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << out->DebugString(); |
|
|
|
} |
|
|
|
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(out, 0); |
|
|
|
auto &anf_node = item_with_index.first; |
|
|
|
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors)); |
|
|
|
} else { |
|
|
|
outs.emplace_back(MSTensorRef(out_tensors->front())); |
|
|
|
out_tensors->pop_front(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (!out_tensors->empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Number of output size " << outs.size() << " but " << out_tensors->size() |
|
|
|
<< " MSTensor remained."; |
|
|
|
} |
|
|
|
|
|
|
|
return outs; |
|
|
|
} |
|
|
|
|
|
|
|
VectorRef MultiGraphAclSession::ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, |
|
|
|
std::deque<MSTensor> *out_tensors) { |
|
|
|
MS_EXCEPTION_IF_NULL(out_tensors); |
|
|
|
VectorRef outs; |
|
|
|
for (size_t i = 1; i < tuple_node->inputs().size(); ++i) { |
|
|
|
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(tuple_node->input(i), 0); |
|
|
|
auto &anf_node = item_with_index.first; |
|
|
|
if (out_tensors->empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << anf_node->DebugString(); |
|
|
|
} |
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors)); |
|
|
|
} else { |
|
|
|
outs.emplace_back(MSTensorRef(out_tensors->front())); |
|
|
|
out_tensors->pop_front(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return outs; |
|
|
|
} |
|
|
|
|
|
|
|
class AclBackend : public compile::MsBackend { |
|
|
|
|