|
|
|
@@ -89,9 +89,6 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { |
|
|
|
}
|
|
|
|
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
|
|
|
|
MS_EXCEPTION_IF_NULL(cur_input);
|
|
|
|
if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
|
|
|
|
const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
|
|
|
|
|
|
|
|
@@ -101,6 +98,26 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(cast);
|
|
|
|
cast->set_scope(cnode->scope());
|
|
|
|
cnode->set_input(input_index + 1, cast);
|
|
|
|
auto real_input = AnfAlgo::VisitKernel(cur_input, 0).first;
|
|
|
|
if (AnfAlgo::IsUpdateParameterKernel(cnode) && real_input->isa<Parameter>() &&
|
|
|
|
AnfAlgo::IsParameterWeight(real_input->cast<ParameterPtr>())) {
|
|
|
|
auto first_depend_node =
|
|
|
|
func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), cast, cnode});
|
|
|
|
first_depend_node->set_abstract(cast->abstract());
|
|
|
|
auto post_cast = AddCastOpNodeToGraph(func_graph, first_depend_node, dev_fmt, device_type, origin_type,
|
|
|
|
origin_shape, origin_type);
|
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
kernel_graph->AddRefCorrespondPairs(std::make_pair(post_cast, 0), AnfAlgo::VisitKernel(cur_input, 0));
|
|
|
|
auto second_depend_node = func_graph->NewCNode(
|
|
|
|
{NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), cnode, post_cast});
|
|
|
|
second_depend_node->set_abstract(cnode->abstract());
|
|
|
|
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, 0);
|
|
|
|
for (size_t j = 0; j < used_node_list->size(); j++) {
|
|
|
|
auto used_node = used_node_list->at(j).first;
|
|
|
|
utils::cast<CNodePtr>(used_node)->set_input(used_node_list->at(j).second, second_depend_node);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|