Browse Source

Put the kernel graphs that are cut by a partial node into the same group.

tags/v1.6.0
gaoyong10 4 years ago
parent
commit
3e7067d284
5 changed files with 9 additions and 43 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
  2. +0
    -33
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  3. +1
    -5
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  4. +2
    -1
      mindspore/ccsrc/runtime/framework/control_node_parser.cc
  5. +5
    -3
      mindspore/ccsrc/vm/backend.cc

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h View File

@@ -102,7 +102,7 @@ class GpuKernel : public KernelMod {
if ((addr_list[index] == nullptr) || (addr_list[index]->addr == nullptr) || (addr_list[index]->size == 0)) {
auto kernel_node = kernel_node_.lock();
const std::string &prim_name = AnfAlgo::GetCNodeName(kernel_node);
const std::string &prim_name = (kernel_node == nullptr ? "" : AnfAlgo::GetCNodeName(kernel_node));
MS_LOG(EXCEPTION) << "The device address is empty, address index: " << index << ", op name is: " << prim_name;
}


+ 0
- 33
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -397,39 +397,6 @@ size_t AnfRuntimeAlgorithm::GetOutputNumByAbstract(const AbstractBasePtr &node_a
return result;
}

std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputByCallNode(const KernelWithIndex &output_with_index) {
MS_EXCEPTION_IF_NULL(output_with_index.first);
auto node_abstract = output_with_index.first->abstract();
MS_EXCEPTION_IF_NULL(node_abstract);
if (!node_abstract->isa<abstract::AbstractTuple>()) {
return {output_with_index};
}

auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
if (GetOutputNumByAbstract(tuple_abstract) <= output_with_index.second) {
MS_LOG(EXCEPTION) << "Invalid index:" << output_with_index.second
<< "for node:" << output_with_index.first->DebugString();
}

// There may be tuples in the output of the call node, these outputs will be all numbered, so it is necessary
// to count the number of outputs before the target in order to accurately obtain its number.
size_t pre_output_num = 0;
for (size_t i = 0; i < output_with_index.second; ++i) {
MS_EXCEPTION_IF_NULL(sub_abstracts[i]);
pre_output_num += GetOutputNumByAbstract(sub_abstracts[i]);
}

MS_EXCEPTION_IF_NULL(sub_abstracts[output_with_index.second]);
size_t output_num = GetOutputNumByAbstract(sub_abstracts[output_with_index.second]);
std::vector<KernelWithIndex> results;
for (size_t i = 0; i < output_num; ++i) {
results.emplace_back(output_with_index.first, pre_output_num + i);
}
return results;
}

std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
auto ret = GetAllOutputWithIndexInner(node);
std::map<AnfNodePtr, size_t> value_node_index;


+ 1
- 5
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h View File

@@ -334,14 +334,10 @@ class AnfRuntimeAlgorithm {
static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph);
static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
// Check whether node is a call node, there are two types of call nodes:
// 1. First input of node is a cnode.
// 2. First input of node is a funcgraph value node.
// Check whether node is a call node. Call nodes are those cnodes whose first input is not a primitive node.
static bool IsCallNode(const AnfNodePtr &node);
// Get the output number according to abstract, when there is a tuple in abstract, it needs to get recursively.
static size_t GetOutputNumByAbstract(const AbstractBasePtr &node_abstract);
// Fetch all outputs of call node.
static std::vector<KernelWithIndex> GetAllOutputByCallNode(const KernelWithIndex &output_with_index);
// Get attr groups
static int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index);



+ 2
- 1
mindspore/ccsrc/runtime/framework/control_node_parser.cc View File

@@ -1788,7 +1788,8 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte

// Collect outputs in group.
for (const auto &backend_to_front : kernel_graph->graph_output_map()) {
if (HasAbstractMonad(backend_to_front.second.first) || HasAbstractMonad(backend_to_front.first.first)) {
if (HasAbstractMonad(backend_to_front.second.first) || HasAbstractMonad(backend_to_front.first.first) ||
AnfAlgo::CheckPrimitiveType(backend_to_front.second.first, prim::kPrimPartial)) {
continue;
}
MS_LOG(DEBUG) << "Kernel graph:" << kernel_graph->ToString()


+ 5
- 3
mindspore/ccsrc/vm/backend.cc View File

@@ -541,9 +541,11 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu
MS_EXCEPTION_IF_NULL(cut_node);
MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString();
control_nodes_.push_back(cut_node);
const auto &func_graph = cut_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>());
if (AnfAlgo::IsCallNode(cut_node)) {
const auto &func_graph = cut_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>());
}
}
}



Loading…
Cancel
Save