|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
- * Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -166,6 +166,63 @@ struct CallInfo { |
|
|
|
AnfNodePtr label_param = nullptr; |
|
|
|
}; |
|
|
|
|
|
|
|
// |
|
|
|
// ParameterPool caches parameters by their abstract, so that a parameter
// with the same abstract can be reused to store return values (see the
// usage sketch after this class).
|
|
|
// |
|
|
|
class ParameterPool { |
|
|
|
public: |
|
|
|
explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {} |
|
|
|
~ParameterPool() = default; |
|
|
|
|
|
|
|
// Create or get a parameter from the pool with the given abstract.
|
|
|
AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) { |
|
|
|
// Find parameter in pool by the given abstract. |
|
|
|
auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto ¶) { |
|
|
|
auto para_abs = para->abstract(); |
|
|
|
// Reuse output parameter with compatible abstract. |
|
|
|
return IsCompatible(abs, para_abs); |
|
|
|
}); |
|
|
|
// Return the parameter if found. |
|
|
|
if (iter != paras_.end()) { |
|
|
|
return *iter; |
|
|
|
} |
|
|
|
// If no parameter with the given abstract is found, create a new one.
|
|
|
auto para = top_graph_->NewParameter(abs); |
|
|
|
auto out_para = top_graph_->TransTupleToMakeTuple(para); |
|
|
|
// This is required, so that device memory can be allocated for it. |
|
|
|
top_graph_->AddChildGraphResult(out_para); |
|
|
|
// Save new para to pool. |
|
|
|
paras_.push_back(out_para); |
|
|
|
return out_para; |
|
|
|
} |
|
|
|
|
|
|
|
protected: |
|
|
|
// Check if one abstract is compatible with another abstract. |
|
|
|
static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) { |
|
|
|
if (a1 == nullptr || a2 == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) { |
|
|
|
// This makes AbstractRef compatible with AbstractTensor, by comparing both through the AbstractTensor base class.
|
|
|
auto &t1 = static_cast<abstract::AbstractTensor &>(*a1); |
|
|
|
auto &t2 = static_cast<abstract::AbstractTensor &>(*a2); |
|
|
|
return t1 == t2; |
|
|
|
} |
|
|
|
return *a1 == *a2; |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
// The top graph. |
|
|
|
const KernelGraphPtr &top_graph_; |
|
|
|
|
|
|
|
// Cached parameters. |
|
|
|
std::vector<AnfNodePtr> paras_; |
|
|
|
}; |
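// A minimal usage sketch of ParameterPool, illustrative only: `top_graph` and
// `float_abs` are placeholder names invented for this example, not part of
// the change above.
//
//   ParameterPool pool(top_graph);
//   // The first request for an abstract creates a fresh top-graph parameter.
//   AnfNodePtr p1 = pool.GetParameter(float_abs);
//   // A later request with a compatible abstract reuses that parameter, so
//   // call sites returning values of the same type/shape share one slot.
//   AnfNodePtr p2 = pool.GetParameter(float_abs);
//   // p1 == p2 holds here, by the find_if lookup in GetParameter.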
|
|
|
|
|
|
|
// |
|
|
|
// Base class for context. |
|
|
|
// |
|
|
|
class BaseContext { |
|
|
|
public: |
|
|
|
void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); } |
|
|
|
@@ -185,7 +242,7 @@ class BaseContext { |
|
|
|
// |
|
|
|
class AscendAutoMonadContext : public BaseContext { |
|
|
|
public: |
|
|
|
- explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg) {}
|
|
|
+ explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg), param_pool_(kg) {}
|
|
|
~AscendAutoMonadContext() = default; |
|
|
|
|
|
|
|
// Label ids start from 1, and increase by 1 for each new id.
|
|
|
@@ -204,6 +261,9 @@ class AscendAutoMonadContext : public BaseContext { |
|
|
|
return out_para; |
|
|
|
} |
|
|
|
|
|
|
|
// Get or create a temporary parameter for the given abstract. |
|
|
|
AnfNodePtr GetTempParameter(const AbstractBasePtr &abs) { return param_pool_.GetParameter(abs); } |
|
|
|
|
|
|
|
const KernelGraphPtr &TopGraph() const { return top_graph_; } |
|
|
|
|
|
|
|
// Map kernel_graph to its call info. |
|
|
|
@@ -213,8 +273,8 @@ class AscendAutoMonadContext : public BaseContext { |
|
|
|
// The top graph. |
|
|
|
const KernelGraphPtr &top_graph_; |
|
|
|
|
|
|
|
// Map kernel_graph to its output parameter. |
|
|
|
std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_; |
|
|
|
// The parameter pool that caches parameters for return values.
|
|
|
ParameterPool param_pool_; |
|
|
|
|
|
|
|
// Current label id. |
|
|
|
uint32_t label_id_ = 1; |
|
|
|
@@ -521,9 +581,18 @@ class AscendAutoMonadConverter { |
|
|
|
auto label_node = LabelSet(call_site.return_label); |
|
|
|
AnfNodePtr output = call_site.out_param; |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
- // Let output depend on the label node, this ensures the
- // return label is set before output is used.
- output = MakeDepend(output, label_node);
|
|
|
+ const bool is_single_call = call_site.label_indexes.empty();
+ if (is_single_call) {
+   // For a single call, let output depend on the label node; this
+   // ensures the return label is set before output is used.
+   output = MakeDepend(output, label_node);
+ } else {
+   // For a multi-return call, assign the result from the temp parameter to
+   // the output parameter; this prevents the result from being overwritten
+   // by the next call.
+   auto tmp_param = context_.GetTempParameter(output->abstract());
+   output = AssignAll(output, tmp_param);
+   monad_ = UpdateState(GetMonad(), output);
+ }
|
|
|
// Replace the call/switch node with the output.
|
|
|
ReplaceNode(cnode, output); |
|
|
|
return; |
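// For reference, a sketch of what the MakeDepend helper used above plausibly
// looks like; this is an assumption for illustration (the real helper is
// defined elsewhere in this file), but Depend in MindSpore evaluates to its
// first input while adding an ordering edge on the second:
//
//   AnfNodePtr MakeDepend(const AnfNodePtr &value, const AnfNodePtr &depend) {
//     // Depend(value, depend): returns `value`, but forces `depend` (here,
//     // the LabelSet node) to be scheduled before any user of `value`.
//     auto depend_cnode = kernel_graph_->NewCNode({NewValueNode(prim::kPrimDepend), value, depend});
//     depend_cnode->set_abstract(value->abstract());
//     return depend_cnode;
//   }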
|
|
|
@@ -603,12 +672,12 @@ class AscendAutoMonadConverter { |
|
|
|
if (return_points.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
+ // Assign output according to the return points.
+ AssignOutput(return_points);
|
|
|
// Single return point. |
|
|
|
if (return_points.size() == 1) { |
|
|
|
- // Insert Assign for output parameter.
- auto &return_point = return_points.front();
- AssignOutput(return_point);
|
|
|
// Insert label_goto for return. |
|
|
|
+ auto &return_point = return_points.front();
|
|
|
auto return_goto = LabelGoto(return_point.call_site->return_label); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto); |
|
|
|
kernel_graph_->set_end_goto(return_goto); |
|
|
|
@@ -617,12 +686,9 @@ class AscendAutoMonadConverter { |
|
|
|
// Multi return points. |
|
|
|
std::vector<uint32_t> return_labels; |
|
|
|
+ return_labels.reserve(return_points.size());
|
|
|
- for (auto &return_point : return_points) {
-   // Assign output to out_params of each return point.
-   AssignOutput(return_point);
-   // Get return labels.
-   return_labels.emplace_back(return_point.call_site->return_label);
- }
|
|
|
+ // Get return labels from return points.
+ std::transform(return_points.begin(), return_points.end(), std::back_inserter(return_labels),
+                [](const ReturnPoint &return_point) { return return_point.call_site->return_label; });
|
|
|
// Insert label_switch for multi return points. |
|
|
|
auto &label_param = call_info_.label_param; |
|
|
|
MS_EXCEPTION_IF_NULL(label_param); |
|
|
|
@@ -631,11 +697,18 @@ class AscendAutoMonadConverter { |
|
|
|
kernel_graph_->set_end_goto(return_switch); |
|
|
|
} |
|
|
|
|
|
|
|
- // Assign graph output to the output parameter for a return point.
- void AssignOutput(const ReturnPoint &return_point) {
-   auto call_site = return_point.call_site;
|
|
|
+ // Assign graph output to the output parameter.
+ void AssignOutput(const std::vector<ReturnPoint> &return_points) {
+   // For a single call: assign output directly to the output parameter of the call site.
+   // For a multi call: assign output to a temp parameter, and let the caller assign the
+   //   temp parameter to an output parameter after the call returns.
+   auto call_site = return_points.front().call_site;
|
|
|
MS_EXCEPTION_IF_NULL(call_site); |
|
|
|
-   auto assign_output = AssignAll(call_site->out_param, kernel_graph_->output());
|
|
|
+   const bool is_single_call = (return_points.size() == 1 && call_site->label_indexes.empty());
+   AnfNodePtr out_param =
+       (is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
+   MS_EXCEPTION_IF_NULL(out_param);
+   auto assign_output = AssignAll(out_param, kernel_graph_->output());
|
|
|
monad_ = UpdateState(GetMonad(), assign_output); |
|
|
|
} |
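// To make the two paths above concrete, a hedged walk-through with invented
// names (subgraph g, call sites A and B):
//
//   // Multi call: g is called from A and B and returns through label_switch.
//   // This function assigns g's output to a shared temp parameter,
//   //   Assign(tmp_param, g.output)
//   // and each call site copies it out right after its return label,
//   //   Assign(A.out_param, tmp_param)  /  Assign(B.out_param, tmp_param)
//   // -- the AssignAll(output, tmp_param) branch in the call-return handling above.
//
//   // Single call: g is called only from A, so this function assigns directly,
//   //   Assign(A.out_param, g.output)
//   // and the call site only adds a Depend on the return label.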
|
|
|
|
|
|
|
@@ -699,7 +772,7 @@ class AscendAutoMonadConverter { |
|
|
|
// For some cnodes, attributes may be set on the primitive instance, so we create a new primitive instance for each cnode to avoid sharing attributes between nodes.
|
|
|
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); } |
|
|
|
|
|
|
|
- AnfNodePtr GetAssignMonad() {
|
|
|
+ AnfNodePtr GetLinkMonad() {
|
|
|
if (last_monad_ != nullptr) { |
|
|
|
return last_monad_; |
|
|
|
} |
|
|
|
@@ -708,7 +781,7 @@ class AscendAutoMonadConverter { |
|
|
|
|
|
|
|
// Make an Assign cnode.
|
|
|
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) { |
|
|
|
- auto monad = GetAssignMonad();
|
|
|
+ auto monad = (is_link ? GetLinkMonad() : GetMonad());
|
|
|
auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name()); |
|
|
|
if (is_link) { |
|
|
|
// Mark that this assign links a real argument to a formal argument.
|
|
|
|