Browse Source

!28244 insert cast for cpu weight node

Merge pull request !28244 from baihuawei/reconstruct_insert_cast
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
0b0ed389c4
2 changed files with 20 additions and 22 deletions
  1. +20
    -3
      mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc
  2. +0
    -19
      mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc

+ 20
- 3
mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc View File

@@ -89,9 +89,6 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
}
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
MS_EXCEPTION_IF_NULL(cur_input);
if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
continue;
}
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
const abstract::BaseShapePtr origin_shape = AnfAlgo::GetOutputDetailShape(prev_node.first, prev_node.second);
@@ -101,6 +98,26 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cast);
cast->set_scope(cnode->scope());
cnode->set_input(input_index + 1, cast);
auto real_input = AnfAlgo::VisitKernel(cur_input, 0).first;
if (AnfAlgo::IsUpdateParameterKernel(cnode) && real_input->isa<Parameter>() &&
AnfAlgo::IsParameterWeight(real_input->cast<ParameterPtr>())) {
auto first_depend_node =
func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), cast, cnode});
first_depend_node->set_abstract(cast->abstract());
auto post_cast = AddCastOpNodeToGraph(func_graph, first_depend_node, dev_fmt, device_type, origin_type,
origin_shape, origin_type);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->AddRefCorrespondPairs(std::make_pair(post_cast, 0), AnfAlgo::VisitKernel(cur_input, 0));
auto second_depend_node = func_graph->NewCNode(
{NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), cnode, post_cast});
second_depend_node->set_abstract(cnode->abstract());
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, 0);
for (size_t j = 0; j < used_node_list->size(); j++) {
auto used_node = used_node_list->at(j).first;
utils::cast<CNodePtr>(used_node)->set_input(used_node_list->at(j).second, second_depend_node);
}
}
}
}
}


+ 0
- 19
mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc View File

@@ -44,24 +44,6 @@ bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
return false;
}

// Propagates the selected kernel's input dtype onto weight parameters that feed `kernel_node`.
//
// For each index listed in `input_not_cnode_indexes`, the input is resolved through AnfAlgo::VisitKernel.
// When the resolved node is a Parameter for which AnfAlgo::IsParameterWeight returns true, a new
// KernelBuildInfo is built with one output (format kOpFormat_DEFAULT, device type taken from
// `kernel_attr.GetInputAttr(index).first`) and installed on that parameter via
// AnfAlgo::SetSelectKernelBuildInfo. Inputs that are not weight parameters are left untouched.
//
// kernel_attr:             kernel attribute whose per-input dtype is read.
// input_not_cnode_indexes: indexes of the inputs to inspect; each is used both as an index into
//                          `kernel_attr` and, offset by one, as an index into `kernel_node`'s inputs.
// kernel_node:             node whose inputs are inspected; must not be null.
void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector<size_t> &input_not_cnode_indexes,
                                   const CNodePtr &kernel_node) {
  // Guard the dereference below, consistent with the null checks used on the other pointers here.
  MS_EXCEPTION_IF_NULL(kernel_node);
  for (const size_t input_index : input_not_cnode_indexes) {
    // NOTE(review): the +1 presumably skips the primitive held in CNode input slot 0 - confirm.
    auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first;
    MS_EXCEPTION_IF_NULL(input_node);
    const bool is_weight =
      input_node->isa<Parameter>() && AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>());
    if (!is_weight) {
      continue;
    }
    // The parameter exposes a single output; give it the dtype the selected kernel expects for this input.
    std::vector<TypeId> output_types;
    output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first);
    auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
    MS_EXCEPTION_IF_NULL(builder);
    builder->SetOutputsFormat({kOpFormat_DEFAULT});
    builder->SetOutputsDeviceType(output_types);
    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get());
  }
}

void GetOutputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *output_types) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
@@ -487,7 +469,6 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
if (matched.first || input_types.size() == input_not_cnode_indexes.size()) {
MS_LOG(INFO) << "Input format and dtype is matched";
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &selected_output_formats, &selected_output_types);
UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node);
for (size_t index = 0; index < selected_kernel_attr.GetInputSize(); index++) {
input_types[index] = selected_kernel_attr.GetInputAttr(index).first;
input_formats.emplace_back(selected_kernel_attr.GetInputAttr(index).second);


Loading…
Cancel
Save