|
|
|
@@ -13,17 +13,16 @@ |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
#include "backend/optimizer/graph_kernel/update_state_formatter.h" |
|
|
|
#include "backend/optimizer/graph_kernel/core/update_state_formatter.h" |
|
|
|
|
|
|
|
#include <vector> |
|
|
|
#include <set> |
|
|
|
#include <memory> |
|
|
|
#include <utility> |
|
|
|
#include <algorithm> |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "backend/kernel_compiler/common_utils.h" |
|
|
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" |
|
|
|
#include "ir/anf.h" |
|
|
|
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h" |
|
|
|
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h" |
|
|
|
#include "backend/optimizer/graph_kernel/core/eliminate_redundant_output.h" |
|
|
|
|
|
|
|
namespace mindspore::graphkernel { |
|
|
|
@@ -52,7 +51,7 @@ AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList |
|
|
|
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
tuple_getitem->set_abstract(node_abstract[i]); |
|
|
|
tuple_getitem->set_kernel_info(std::make_shared<device::KernelInfo>()); |
|
|
|
Callback::Instance()->SetEmptyKernelInfo(tuple_getitem); |
|
|
|
result.push_back(tuple_getitem); |
|
|
|
} |
|
|
|
} else { |
|
|
|
@@ -103,12 +102,12 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) { |
|
|
|
mt_inputs.insert(mt_inputs.begin(), NewValueNode(prim::kPrimMakeTuple)); |
|
|
|
auto mt_node = func_graph->NewCNode(mt_inputs); |
|
|
|
mt_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list)); |
|
|
|
mt_node->set_kernel_info(std::make_shared<device::KernelInfo>()); |
|
|
|
Callback::Instance()->SetEmptyKernelInfo(mt_node); |
|
|
|
|
|
|
|
AnfNodePtrList inputs = {cnode->input(0), cnode->input(1), mt_node}; |
|
|
|
auto new_node = func_graph->NewCNode(inputs); |
|
|
|
new_node->set_abstract(node->abstract()); |
|
|
|
new_node->set_kernel_info(std::make_shared<device::KernelInfo>()); |
|
|
|
Callback::Instance()->SetEmptyKernelInfo(new_node); |
|
|
|
(void)mng->Replace(node, new_node); |
|
|
|
changed = true; |
|
|
|
} |
|
|
|
@@ -125,7 +124,7 @@ bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) { |
|
|
|
if (getitems_.empty()) continue; |
|
|
|
FindIndexesToUpdateState(mng); |
|
|
|
if (indexes_.empty()) continue; |
|
|
|
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); |
|
|
|
auto sub_func_graph = GetCNodeFuncGraph(node); |
|
|
|
FilterIndexes(sub_func_graph); |
|
|
|
if (indexes_.empty()) continue; |
|
|
|
for (auto idx : indexes_) { |
|
|
|
@@ -133,7 +132,7 @@ bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) { |
|
|
|
} |
|
|
|
} |
|
|
|
if (changed) { |
|
|
|
UpdateMng(mng, func_graph); |
|
|
|
GkUtils::UpdateFuncGraphManager(mng, func_graph); |
|
|
|
std::make_shared<SpreadUpdateState>()->Run(func_graph); |
|
|
|
std::make_shared<EliminateHangingOutput>()->Run(func_graph); |
|
|
|
} |