Merge pull request !3987 from zhoufeng/graph-compile-performace
tags/v0.7.0-beta
@@ -52,13 +52,13 @@ TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
   return outputs_device_type_[output_index];
 }
-std::vector<std::string> KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; }
+const std::vector<std::string> &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; }
-std::vector<std::string> KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; }
+const std::vector<std::string> &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; }
-std::vector<TypeId> KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; }
+const std::vector<TypeId> &KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; }
-std::vector<TypeId> KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; }
+const std::vector<TypeId> &KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; }
 size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }

@@ -63,13 +63,13 @@ class KernelBuildInfo {
   std::vector<Axis> GetOutputReshapeType(size_t input_index) const;
-  std::vector<std::string> GetAllInputFormats() const;
+  const std::vector<std::string> &GetAllInputFormats() const;
-  std::vector<std::string> GetAllOutputFormats() const;
+  const std::vector<std::string> &GetAllOutputFormats() const;
-  std::vector<TypeId> GetAllInputDeviceTypes() const;
+  const std::vector<TypeId> &GetAllInputDeviceTypes() const;
-  std::vector<TypeId> GetAllOutputDeviceTypes() const;
+  const std::vector<TypeId> &GetAllOutputDeviceTypes() const;
   std::vector<std::vector<Axis>> GetAllOutputReshapeType() const;
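The getters above switch from returning a `std::vector` by value to returning a `const` reference, so each call no longer allocates and copies the whole container. A minimal standalone sketch of the difference (class and member names here are illustrative, not from the MindSpore codebase):

```cpp
#include <string>
#include <vector>

class KernelInfo {
 public:
  // Old style: every call copies the vector (one allocation + N string copies).
  std::vector<std::string> FormatsByValue() const { return formats_; }

  // New style: returns a reference to the member; no copy is made.
  // Safe as long as the reference does not outlive this object.
  const std::vector<std::string> &FormatsByRef() const { return formats_; }

 private:
  std::vector<std::string> formats_{"NCHW", "NC1HWC0", "DefaultFormat"};
};

int main() {
  KernelInfo info;
  const auto &formats = info.FormatsByRef();  // binds for free, no copy
  return formats.size() == 3 ? 0 : 1;
}
```

In hot paths such as kernel selection, where these getters run once per candidate kernel per input, eliminating the copies is exactly the kind of win the branch name ("graph-compile-performace") is after.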
@@ -25,6 +25,10 @@ namespace kernel {
  * @brief fuse op and return a callable mod
  */
 struct FusionScopeInfo {
+  FusionScopeInfo() {}
+  FusionScopeInfo(int32_t id, const std::vector<AnfNodePtr> &in, const std::vector<AnfNodePtr> &comp,
+                  const std::vector<AnfNodePtr> &out)
+      : scope_id(id), input_nodes(in), compute_nodes(comp), output_nodes(out) {}
   int32_t scope_id;
   std::vector<AnfNodePtr> input_nodes;
   std::vector<AnfNodePtr> compute_nodes;
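The new constructor lets callers build a fully initialized `FusionScopeInfo` in one expression instead of default-constructing and assigning each field; the `std::transform` call later in this PR relies on it. A simplified illustration with a stand-in struct (`Node`/`NodePtr` are hypothetical substitutes for `AnfNodePtr`):

```cpp
#include <cstdint>
#include <memory>
#include <vector>

struct Node {};
using NodePtr = std::shared_ptr<Node>;  // stand-in for AnfNodePtr

struct ScopeInfo {
  ScopeInfo() {}
  ScopeInfo(int32_t id, const std::vector<NodePtr> &in, const std::vector<NodePtr> &comp,
            const std::vector<NodePtr> &out)
      : scope_id(id), input_nodes(in), compute_nodes(comp), output_nodes(out) {}
  int32_t scope_id = 0;
  std::vector<NodePtr> input_nodes;
  std::vector<NodePtr> compute_nodes;
  std::vector<NodePtr> output_nodes;
};

int main() {
  std::vector<NodePtr> ins{std::make_shared<Node>()}, comps, outs;
  ScopeInfo info(42, ins, comps, outs);  // one expression, no half-initialized state
  return info.scope_id == 42 ? 0 : 1;
}
```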
@@ -59,13 +59,13 @@ class OpIOInfo {
   ~OpIOInfo() = default;

   int index() const { return index_; }
-  std::string name() const { return name_; }
+  const std::string &name() const { return name_; }
   bool need_compile() const { return need_compile_; }
-  std::string param_type() const { return param_type_; }
-  std::string reshape_type() const { return reshape_type_; }
-  std::string shape() const { return shape_; }
-  std::vector<std::string> dtypes() const { return dtypes_; }
-  std::vector<std::string> formats() const { return formats_; }
+  const std::string &param_type() const { return param_type_; }
+  const std::string &reshape_type() const { return reshape_type_; }
+  const std::string &shape() const { return shape_; }
+  const std::vector<std::string> &dtypes() const { return dtypes_; }
+  const std::vector<std::string> &formats() const { return formats_; }
   void set_index(const int index) { index_ = index; }
   void set_name(const std::string &name) { name_ = name; }
@@ -336,13 +336,11 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
                   << ", current op num: " << op_info_.size();
     return nullptr;
   }
+  std::string target_processor = is_gpu ? kCUDA : kAiCore;
   for (const auto &op_info : op_info_) {
     MS_EXCEPTION_IF_NULL(op_info);
     if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) {
-      auto akg_processor_match = [&]() {
-        return is_gpu ? op_info->processor() == kCUDA : op_info->processor() == kAiCore;
-      };
-      if (imply_type != kAKG || akg_processor_match()) {
+      if (imply_type != kAKG || op_info->processor() == target_processor) {
         return op_info;
       }
     }
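Here `target_processor` is computed once before the loop instead of re-evaluating the `is_gpu` ternary inside a lambda on every matching iteration. Hoisting a loop-invariant value is a small but reliable win; a generic sketch (identifiers are illustrative):

```cpp
#include <string>
#include <vector>

std::string Pick(bool is_gpu, const std::vector<std::string> &processors) {
  // Hoisted: the target does not depend on the loop variable.
  const std::string target = is_gpu ? "cuda" : "aicore";
  for (const auto &p : processors) {
    if (p == target) {  // one string comparison per element, no branch on is_gpu
      return p;
    }
  }
  return "";
}

int main() { return Pick(true, {"aicore", "cuda"}) == "cuda" ? 0 : 1; }
```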
@@ -82,7 +82,6 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
   }
   // check support
   FilterInVaildKernelInfo();
-  MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select.";
 }

 void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {

@@ -221,38 +220,37 @@ void TbeKernelSelect::FilterInVaildKernelInfo() {
     MS_LOG(INFO) << "Warning: get kernel build info failed.";
     return;
   }
-  auto kernel_build_info_iter = kernel_info_list_->begin();
-  while (kernel_build_info_iter != kernel_info_list_->end()) {
-    if (!FilterInVaildShape(kernel_build_info_iter)) {
-      MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*kernel_build_info_iter)->ToString();
-      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
+  std::vector<std::shared_ptr<KernelBuildInfo>> new_kernel_info_list;
+  for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) {
+    if (!FilterInVaildShape(iter)) {
+      MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*iter)->ToString();
       continue;
     }
-    if (!TbeCheckSupported(kernel_build_info_iter)) {
-      MS_LOG(INFO) << "Check support shape, filter item info: " << (*kernel_build_info_iter)->ToString();
-      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
+    if (!TbeCheckSupported(iter)) {
+      MS_LOG(INFO) << "Check support shape, filter item info: " << (*iter)->ToString();
       continue;
     }
-    kernel_build_info_iter++;
+    new_kernel_info_list.emplace_back(*iter);
   }
+  (*kernel_info_list_) = new_kernel_info_list;
 }
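Rewriting the erase-while-iterating loop as "copy the survivors into a fresh vector, then assign" turns a potentially quadratic pass (each `vector::erase` shifts the whole tail) into a linear one. The standard library expresses the same idea with the erase-remove idiom; a sketch under that assumption, with a stand-in `BuildInfo` type:

```cpp
#include <algorithm>
#include <memory>
#include <vector>

struct BuildInfo { bool valid; };  // stand-in for KernelBuildInfo
using BuildInfoPtr = std::shared_ptr<BuildInfo>;

void FilterInvalid(std::vector<BuildInfoPtr> *list) {
  // Equivalent effect to the PR's rebuild-and-assign, in one linear pass:
  list->erase(std::remove_if(list->begin(), list->end(),
                             [](const BuildInfoPtr &info) { return !info->valid; }),
              list->end());
}

int main() {
  std::vector<BuildInfoPtr> list{std::make_shared<BuildInfo>(BuildInfo{true}),
                                 std::make_shared<BuildInfo>(BuildInfo{false})};
  FilterInvalid(&list);
  return list.size() == 1 ? 0 : 1;
}
```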
 bool TbeKernelSelect::FilterInVaildShape(
     const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
   MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
-  auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
+  const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
   for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) {
     auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
-    auto format = kernel_build_info_inputs_format.at(i);
+    const auto &format = kernel_build_info_inputs_format[i];
     if (!IsShapeMatchFormat(shape, format)) {
       MS_LOG(INFO) << "The " << i << "th input check failed.";
       return false;
     }
   }
-  auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
+  const auto &kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
   for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
     auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
-    auto format = kernel_build_info_outputs_format.at(j);
+    const auto &format = kernel_build_info_outputs_format[j];
     if (!IsShapeMatchFormat(shape, format)) {
       MS_LOG(INFO) << "The " << j << "th input check failed.";
       return false;

@@ -344,12 +342,12 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
   size_t io_info_num = ios_info.size();
   for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) {
     std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index];
-    auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index);
+    const auto &kernel_build_info_dtype = io_info_item->dtypes()[kernel_build_info_index];
     std::string kernel_build_info_format;
     if (!io_info_item->formats().empty()) {
-      kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index);
+      kernel_build_info_format = io_info_item->formats()[kernel_build_info_index];
     }
-    std::string io_param_type = io_info_item->param_type();
+    const std::string &io_param_type = io_info_item->param_type();
     std::vector<Axis> reshape_type;
     StringToAxisVector(io_info_item->reshape_type(), &reshape_type);
     if (io_param_type == kParamTypeDynamic) {
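Binding `const auto &kernel_build_info_dtype` to the indexed getter result is only safe because `dtypes()` now returns a reference to a member (see the OpIOInfo hunk above). Had the getter still returned by value, the reference would bind into a temporary vector that dies at the end of the full expression. A cautionary sketch:

```cpp
#include <string>
#include <vector>

struct IoInfo {
  // Returns a reference to the member: indexing the result is safe.
  const std::vector<std::string> &dtypes() const { return dtypes_; }
  std::vector<std::string> dtypes_{"float16", "float32"};
};

int main() {
  IoInfo io;
  // OK: refers into io.dtypes_, which outlives the reference.
  const auto &dtype = io.dtypes()[1];
  // If dtypes() returned by value, `dtype` would dangle here: the temporary
  // vector is destroyed at the end of the full expression that created it.
  return dtype == "float32" ? 0 : 1;
}
```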
@@ -367,6 +365,7 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
       }
       dynamic_input_index++;
       real_io_tensor_index += dynamic_input_size;
     } else {
       if (ios_info.size() != 1) {
         MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output.";

@@ -388,7 +387,6 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
       MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type;
     }
   }
   if (io_info_index != io_info_num) {
     MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num
                  << "), this node may has optional input/output.";
@@ -51,11 +51,11 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                  const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
   AnfNodePtr trans_node = nullptr;
-  AnfNodePtr input_node = node;
+  AnfNodePtr input_node = nullptr;
   CNodePtr trans_data = nullptr;
   std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
   std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
-  std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
+  std::vector<kernel::Axis> padding_axis;
   MS_EXCEPTION_IF_NULL(node);
   // if insert transdata for input we need to change the input
   if (is_insert_input) {

@@ -66,12 +66,17 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
     padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
+  } else {
+    input_node = node;
+    padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
   }
+  auto input_node_out_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
   bool need_padding = false;
   if (is_insert_input) {
-    need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
+    need_padding = (trans::IsNeedPadding(dst_format, input_node_out_shape.size()));
   } else {
-    need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
+    need_padding = (trans::IsNeedPadding(input_format, input_node_out_shape.size()));
   }
   if (!need_padding) {
     // don't need padding insert transdata only

@@ -80,8 +85,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
   } else if (is_insert_input) {
     // if need padding & is input need insert a transdata
     // reshape[padding shape] -> transdata[padding shape] -> node
-    auto padding_shape =
-        trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
+    auto padding_shape = trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, 0));
     auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
     trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
     trans_node = trans_data;

@@ -89,8 +93,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     // if need padding & is output need insert a transdata
     // node -> transdata[padding shape] -> reshape[ori_shape]
     trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
-    auto reshape_node =
-        CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
+    auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
     trans_node = reshape_node;
   }
   // refresh the transdata's format to ori format & dst format
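`AnfAlgo::GetOutputInferShape(input_node, 0)` was being evaluated up to three times on the same node; the PR computes it once into `input_node_out_shape` and reuses it at every site. The pattern, reduced to a stand-in (`InferShape` here is a hypothetical placeholder for the repeated query):

```cpp
#include <cstddef>
#include <vector>

std::vector<size_t> InferShape() { return {1, 16, 224, 224}; }  // stand-in query

bool NeedPadding(const std::vector<size_t> &shape) { return shape.size() == 4; }

int main() {
  // Before: InferShape() evaluated at every use site.
  // After: evaluated once, reused everywhere below.
  auto shape = InferShape();
  bool need_padding = NeedPadding(shape);
  size_t rank = shape.size();
  return (need_padding && rank == 4) ? 0 : 1;
}
```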
@@ -140,10 +143,10 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
                                           const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
-  std::vector<AnfNodePtr> make_tuple_inputs;
-  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
-  for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
+  size_t out_num = AnfAlgo::GetOutputTensorNum(node);
+  for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
     if (output_format == kOpFormat_NC1KHKWHWC0) {
       MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node "

@@ -151,12 +154,12 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
     }
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
-    if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
+    if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) {
       auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
       if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
         kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
       }
-      make_tuple_inputs.emplace_back(trans_op);
+      make_tuple_inputs.push_back(trans_op);
     } else {
       // No need insert trans op.
       make_tuple_inputs.push_back(tuple_getitem);
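Two micro-optimizations in this hunk: the loop bound `GetOutputTensorNum(node)` is hoisted into `out_num` so it is not re-queried on every iteration, and the `if` condition is reordered so the cheap `origin_shape.size() > 1` test short-circuits before the set lookup. A sketch of the reordering (the relative costs are illustrative):

```cpp
#include <cstddef>
#include <set>
#include <string>
#include <vector>

bool ShouldInsertTransOp(const std::vector<size_t> &shape, const std::string &format,
                         const std::set<std::string> &common_formats) {
  // Cheap size check first; the set lookup only runs when it can matter.
  return shape.size() > 1 && common_formats.find(format) == common_formats.end();
}

int main() {
  std::set<std::string> common{"DefaultFormat", "NCHW"};
  return ShouldInsertTransOp({8}, "FRACTAL_Z", common) ? 1 : 0;  // 1-D: lookup skipped
}
```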
@@ -188,15 +191,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                         const bool need_padding, const std::string &op_name) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(input);
-  std::vector<AnfNodePtr> trans_inputs;
-  auto prim = std::make_shared<Primitive>(op_name);
-  trans_inputs.push_back(NewValueNode(prim));
-  trans_inputs.push_back(input);
-  CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
+  CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(op_name)), input});
   MS_EXCEPTION_IF_NULL(trans_node);
-  auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
   if (need_padding) {
     // if need padding we should set the transdata node's shape to the padding shape
+    auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                         {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
                                         trans_node.get());
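`padding_axis` was previously computed unconditionally and then used only inside the `need_padding` branch; moving the call into the branch skips the work entirely on the no-padding path. The general shape of the change (`ComputeReshapeType` is a hypothetical stand-in for the query):

```cpp
#include <vector>

std::vector<int> ComputeReshapeType() { return {0, 1, 2, 3}; }  // stand-in for the costly query

int Process(bool need_padding) {
  if (need_padding) {
    // Only pay for the computation when the result is actually used.
    auto padding_axis = ComputeReshapeType();
    return static_cast<int>(padding_axis.size());
  }
  return 0;
}

int main() { return Process(false); }
```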
@@ -224,11 +223,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
   MS_EXCEPTION_IF_NULL(func_graph);
   std::string input_format = format;
   std::string output_format = format;
-  std::vector<AnfNodePtr> new_cast_inputs;
-  auto prim = std::make_shared<Primitive>(prim::kPrimCast->name());
-  new_cast_inputs.push_back(NewValueNode(prim));
-  new_cast_inputs.push_back(input);
-  CNodePtr cast = func_graph->NewCNode(new_cast_inputs);
+  CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input});
   MS_EXCEPTION_IF_NULL(cast);
   // set kernel build info
   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;

@@ -280,7 +275,8 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
+  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
+  for (size_t input_index = 0; input_index < in_num; ++input_index) {
     AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
     MS_EXCEPTION_IF_NULL(input_node);
     new_inputs.push_back(input_node);
@@ -301,8 +297,10 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
 CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
-    const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
+  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
+  for (size_t input_index = 0; input_index < in_num; ++input_index) {
+    auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
+    const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
     TypeId origin_type(kTypeUnknown);
     auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
     auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);

@@ -311,20 +309,19 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
       // weight
       origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
       if (origin_type == kTypeUnknown) {
-        origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
+        origin_type = AnfAlgo::GetOutputDeviceDataType(prev_node.first, prev_node.second);
       }
     } else {
       // feature map
-      origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
+      origin_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
     }
     const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
-    const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index);
-    const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
+    const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
     // In graph kernel, we check parameter,
     // the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
     if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
       new_inputs.push_back(cur_input);
-    } else if (origin_type != device_type) {
+    } else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
       auto cast =
           AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
       MS_EXCEPTION_IF_NULL(cast);
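Two related changes here: the `(node, output index)` pair behind each input is resolved once via `GetPrevNodeOutput` and then reused for the infer-type, device-type, and shape queries, instead of re-walking the graph through the `GetPrevNodeOutput*` convenience wrappers; and `device_type` moves into a C++17 if-with-initializer so it is computed and scoped exactly where it is compared. The if-initializer idiom in isolation:

```cpp
#include <cstdio>

int QueryDeviceType() { return 42; }  // stand-in for the device-type lookup

void MaybeInsertCast(int origin_type) {
  // C++17: device_type lives only in this if/else, computed right where needed.
  if (int device_type = QueryDeviceType(); origin_type != device_type) {
    std::printf("insert cast %d -> %d\n", origin_type, device_type);
  }
}

int main() {
  MaybeInsertCast(1);
  return 0;
}
```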
@@ -120,8 +120,8 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr
   std::vector<TypeId> inputs_data_type;
   for (const auto &input : inputs_list) {
     auto real_input = AnfAlgo::VisitKernel(input, 0);
-    inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second));
-    inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second));
+    inputs_format.emplace_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second));
+    inputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second));
   }
   // outputs format and data type
   std::vector<std::string> outputs_format;

@@ -130,13 +130,13 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr
     if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
       auto tuple_getitem = output->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(tuple_getitem);
-      outputs_format.push_back(AnfAlgo::GetOutputFormat(
+      outputs_format.emplace_back(AnfAlgo::GetOutputFormat(
           tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
-      outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(
+      outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(
           tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
     } else {
-      outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0));
-      outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0));
+      outputs_format.emplace_back(AnfAlgo::GetOutputFormat(output, 0));
+      outputs_data_type.emplace_back(AnfAlgo::GetOutputDeviceDataType(output, 0));
     }
   }
   builder.SetInputsFormat(inputs_format);
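The `push_back` to `emplace_back` swaps are the mildest change in the PR: for prvalue arguments like the return of `GetOutputFormat`, the two are essentially equivalent (both move-construct the element), while `emplace_back` pays off when you pass constructor arguments and let the element be built in place. A contrast:

```cpp
#include <string>
#include <vector>

int main() {
  std::vector<std::string> v;
  v.push_back(std::string(5, 'x'));  // constructs a temporary, then moves it in
  v.emplace_back(5, 'y');            // constructs the string directly inside the vector
  return (v[0] == "xxxxx" && v[1] == "yyyyy") ? 0 : 1;
}
```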
@@ -229,7 +229,7 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
   for (auto &buffer_fusion_info : *buffer_fusion_infos) {
     auto fusion_id = buffer_fusion_info.first;
-    auto fusion_info = buffer_fusion_info.second;
+    const auto &fusion_info = buffer_fusion_info.second;
     for (const auto &node : fusion_info.anf_nodes) {
       auto cnode = node->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(cnode);

@@ -237,10 +237,10 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
       auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0);
       if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) ==
           fusion_info.anf_nodes.end()) {
-        if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(),
-                      (*buffer_fusion_infos)[fusion_id].inputs_list.end(),
-                      cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) {
-          (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx));
+        if (auto in = cnode->input(idx); std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(),
+                                                   (*buffer_fusion_infos)[fusion_id].inputs_list.end(),
+                                                   in) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) {
+          (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in);
         }
       }
     }

@@ -277,7 +277,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
   for (auto &buffer_fusion_info : *buffer_fusion_infos) {
     auto fusion_id = buffer_fusion_info.first;
-    auto fusion_info = buffer_fusion_info.second;
+    const auto &fusion_info = buffer_fusion_info.second;
     for (const auto &node : fusion_info.anf_nodes) {
       if (AnfAlgo::GetOutputTensorNum(node) == 1) {
         for (auto use_node : manager->node_users()[node]) {

@@ -294,7 +294,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
                        std::back_inserter(tuple_getitem_nodes),
                        [](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
         std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
-        for (auto getitem : tuple_getitem_nodes) {
+        for (auto &getitem : tuple_getitem_nodes) {
           MS_EXCEPTION_IF_NULL(getitem);
           auto getitem_ptr = getitem->cast<CNodePtr>();
           auto input2 = getitem_ptr->input(2);

@@ -304,7 +304,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
             (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node);
           }
           prev_idx = output_idx + 1;
-          for (auto item_use_node : manager->node_users()[getitem]) {
+          for (auto &item_use_node : manager->node_users()[getitem]) {
             if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) ==
                 fusion_info.anf_nodes.end()) {
               (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem);
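Changing `for (auto getitem : ...)` to `for (auto &getitem : ...)` matters because the elements are `shared_ptr`-style handles: copying one bumps an atomic reference count on every iteration, and the `const auto &fusion_info` bindings above likewise avoid deep-copying a struct holding several node vectors. The refcount effect, demonstrated with `use_count`:

```cpp
#include <memory>
#include <vector>

int main() {
  std::vector<std::shared_ptr<int>> nodes{std::make_shared<int>(1)};
  long count_by_ref = 0;
  for (const auto &n : nodes) {  // no copy: refcount stays at 1
    count_by_ref = n.use_count();
  }
  long count_by_val = 0;
  for (auto n : nodes) {         // copy: refcount is temporarily 2
    count_by_val = n.use_count();
  }
  return (count_by_ref == 1 && count_by_val == 2) ? 0 : 1;
}
```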
@@ -365,31 +365,25 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph
   MS_EXCEPTION_IF_NULL(kernel_graph);
   bool change = false;
   std::unordered_map<int32_t, BufferFusionInfo_t> buffer_fusion_infos;
-  buffer_fusion_infos.clear();
   GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos);

   std::vector<mindspore::kernel::FusionScopeInfo> fusion_scope_infos;
-  for (auto &buffer_fusion_info : buffer_fusion_infos) {
-    mindspore::kernel::FusionScopeInfo fusion_scope_info;
-    fusion_scope_info.scope_id = buffer_fusion_info.first;
-    fusion_scope_info.input_nodes = buffer_fusion_info.second.inputs_list;
-    fusion_scope_info.compute_nodes = buffer_fusion_info.second.anf_nodes;
-    fusion_scope_info.output_nodes = buffer_fusion_info.second.outputs_list;
-    fusion_scope_infos.push_back(fusion_scope_info);
-#ifdef DEBUG
-    DumpFusionScopeInfo(fusion_scope_info);
-#endif
-  }
+  std::transform(
+      buffer_fusion_infos.begin(), buffer_fusion_infos.end(), std::back_inserter(fusion_scope_infos),
+      [](const std::pair<int32_t, BufferFusionInfo_t> &buffer_fusion_info) -> mindspore::kernel::FusionScopeInfo {
+        return mindspore::kernel::FusionScopeInfo(buffer_fusion_info.first, buffer_fusion_info.second.inputs_list,
+                                                  buffer_fusion_info.second.anf_nodes,
+                                                  buffer_fusion_info.second.outputs_list);
+      });
   auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos);
-  std::vector<int32_t> fusion_ids;
+  std::set<int32_t> fusion_ids;
   for (auto &buffer_fusion_info : buffer_fusion_infos) {
     MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size()
                   << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size()
                   << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size();
-    fusion_ids.push_back(buffer_fusion_info.first);
+    fusion_ids.insert(buffer_fusion_info.first);
   }
   // Replace fusion op from return to head
-  std::sort(fusion_ids.begin(), fusion_ids.end());
   for (auto &fusion_id : fusion_ids) {
     // Get kernel mod when supporting tbe
     if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) {
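Switching `fusion_ids` from `std::vector<int32_t>` plus `std::sort` to `std::set<int32_t>` folds deduplication and ordering into insertion: iterating a `std::set` always visits keys in ascending order, so the explicit sort before "Replace fusion op from return to head" becomes unnecessary. Minimal check of that property:

```cpp
#include <cstdint>
#include <set>

int main() {
  std::set<int32_t> fusion_ids;
  for (int32_t id : {7, 3, 7, 1}) {
    fusion_ids.insert(id);  // duplicates collapse, order is maintained on insert
  }
  int32_t prev = 0;
  for (int32_t id : fusion_ids) {  // visits 1, 3, 7
    if (id < prev) return 1;
    prev = id;
  }
  return fusion_ids.size() == 3 ? 0 : 1;
}
```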
@@ -414,9 +408,10 @@ bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionIn
   std::vector<TypeId> types;
   std::vector<std::vector<size_t>> shapes;
   for (const auto &out_node : buffer_fusion_info.outputs_list) {
-    for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(out_node); ++idx) {
-      types.push_back(AnfAlgo::GetOutputInferDataType(out_node, idx));
-      shapes.push_back(AnfAlgo::GetOutputInferShape(out_node, idx));
+    size_t out_num = AnfAlgo::GetOutputTensorNum(out_node);
+    for (size_t idx = 0; idx < out_num; ++idx) {
+      types.emplace_back(AnfAlgo::GetOutputInferDataType(out_node, idx));
+      shapes.emplace_back(AnfAlgo::GetOutputInferShape(out_node, idx));
     }
   }
   if (types.empty() || shapes.empty()) {
@@ -30,12 +30,13 @@ namespace {
 bool CheckFormatForConsistency(const CNodePtr &node, const size_t input_index) {
   MS_EXCEPTION_IF_NULL(node);
   // get prior node's device output format
-  string pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(node, input_index);
+  auto prev_node = AnfAlgo::GetPrevNodeOutput(node, input_index);
+  string pre_output_format = AnfAlgo::GetOutputFormat(prev_node.first, prev_node.second);
   string selected_input_format = AnfAlgo::GetInputFormat(node, input_index);
   if (pre_output_format == selected_input_format) {
     return true;
   }
-  auto input_origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, input_index);
+  auto input_origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
   if (pre_output_format == kOpFormat_DEFAULT || selected_input_format == kOpFormat_DEFAULT) {
     string checking_format = (pre_output_format == kOpFormat_DEFAULT) ? selected_input_format : pre_output_format;
     // when input shape size is 1D, default format and NC1HWC0 are compatible
 
@@ -87,7 +88,8 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt
   for (auto &t : todos) {
     CNodePtr cnode = t->cast<CNodePtr>();
-    for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
+    size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
+    for (size_t i = 0; i < in_num; ++i) {
       if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
         MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
                           << cnode->DebugString() << "]";
@@ -39,9 +39,10 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
   MS_EXCEPTION_IF_NULL(cnode);
   std::vector<AnfNodePtr> make_tuple_inputs;
   AbstractBasePtrList abstract_list;
-  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
-  for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) {
+  size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
+  for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
     AnfNodePtr replace_node = nullptr;
     const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
     const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);

@@ -74,7 +75,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
     } else {
       replace_node = getitem;
     }
-    abstract_list.push_back(replace_node->abstract());
+    abstract_list.emplace_back(replace_node->abstract());
     make_tuple_inputs.push_back(replace_node);
   }
   AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);

@@ -27,7 +27,7 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_
   MS_EXCEPTION_IF_NULL(origin_addn_cnode);
   std::vector<AnfNodePtr> new_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
   for (size_t i = begin_index; i < begin_index + offset; ++i) {
-    new_addn_inputs.push_back(origin_addn_cnode->input(i));
+    new_addn_inputs.emplace_back(origin_addn_cnode->input(i));
   }
   CNodePtr new_addn = func_graph->NewCNode(new_addn_inputs);
   MS_EXCEPTION_IF_NULL(new_addn);

@@ -66,7 +66,7 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN
     cur_input_index += inputs_divisor_;
   }
   for (size_t i = cur_input_index; i <= origin_input_size; i++) {
-    base_addn_inputs.push_back(new_cnode->input(i));
+    base_addn_inputs.emplace_back(new_cnode->input(i));
   }
   CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
   MS_EXCEPTION_IF_NULL(base_addn);

@@ -37,7 +37,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
   }
   size_t output_num = 0;
   for (const auto &node_index : manager->node_users()[bn]) {
-    AnfNodePtr output = node_index.first;
+    const AnfNodePtr &output = node_index.first;
     MS_EXCEPTION_IF_NULL(output);
     if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
       continue;
@@ -32,7 +32,7 @@ bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
     return false;
   }
   for (const auto &node_index : manager->node_users()[node]) {
-    AnfNodePtr output = node_index.first;
+    const AnfNodePtr &output = node_index.first;
     MS_EXCEPTION_IF_NULL(output);
     if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
       continue;

@@ -33,7 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
                                std::vector<AnfNodePtr> *bn_update_grad_outputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(bn_grad_node);
-  auto bn_grad_inputs = bn_grad_node->inputs();
+  const auto &bn_grad_inputs = bn_grad_node->inputs();
   if (bn_grad_inputs.size() < kBNGradInputNum) {
     MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size";
   }

@@ -58,7 +58,8 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
                                std::vector<AnfNodePtr> *bn_reduce_grad_outputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(bn_grad_node);
-  auto bn_grad_inputs = bn_grad_node->inputs();
+  MS_EXCEPTION_IF_NULL(bn_reduce_grad_outputs);
+  const auto &bn_grad_inputs = bn_grad_node->inputs();
   if (bn_grad_inputs.size() < kBNGradInputNum) {
     MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size";
   }

@@ -25,9 +25,9 @@ AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origi
                            size_t offset) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(origin_concat_cnode);
-  std::vector<AnfNodePtr> new_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
+  std::vector<AnfNodePtr> new_concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
   for (size_t i = begin_index; i < begin_index + offset; ++i) {
-    new_concat_inputs.push_back(origin_concat_cnode->input(i));
+    new_concat_inputs.emplace_back(origin_concat_cnode->input(i));
   }
   CNodePtr new_concat = func_graph->NewCNode(new_concat_inputs);
   MS_EXCEPTION_IF_NULL(new_concat);
@@ -83,7 +83,7 @@ const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const An
     cur_input_index += inputs_divisor_;
   }
   for (size_t i = cur_input_index; i <= origin_input_size; i++) {
-    base_concat_inputs.push_back(new_cnode->input(i));
+    base_concat_inputs.emplace_back(new_cnode->input(i));
   }
   CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
   MS_EXCEPTION_IF_NULL(base_concat);

@@ -31,9 +31,8 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars
     MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum;
   }
-  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSquareSumAllOpName))};
-  inputs.push_back(lars_v2->input(1));
-  inputs.push_back(lars_v2->input(2));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSquareSumAllOpName)), lars_v2->input(1),
+                                    lars_v2->input(2)};
   auto square_sum_all = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(square_sum_all);
   square_sum_all->set_scope(lars_v2->scope());

@@ -56,13 +55,13 @@ CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
   if (lars_v2->size() != kLarsV2InputNum) {
     MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum;
   }
-  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kLarsV2UpdateOpName))};
-  inputs.push_back(lars_v2->input(1));
-  inputs.push_back(lars_v2->input(2));
-  inputs.push_back(square_sum_all_outputs[0]);
-  inputs.push_back(square_sum_all_outputs[1]);
-  inputs.push_back(lars_v2->input(3));
-  inputs.push_back(lars_v2->input(4));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kLarsV2UpdateOpName)),
+                                    lars_v2->input(1),
+                                    lars_v2->input(2),
+                                    square_sum_all_outputs[0],
+                                    square_sum_all_outputs[1],
+                                    lars_v2->input(3),
+                                    lars_v2->input(4)};
   auto lars_v2_update = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(lars_v2_update);
   lars_v2_update->set_scope(lars_v2->scope());

@@ -32,6 +32,7 @@ void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
     std::vector<AnfNodePtr> *layer_norm_x_backprop_outputs) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(layer_norm_grad);
+  MS_EXCEPTION_IF_NULL(layer_norm_x_backprop_outputs);
   auto prim = std::make_shared<Primitive>(kLayerNormXBackpropOpName);
   std::vector<AnfNodePtr> layer_norm_x_backprop_inputs = {NewValueNode(prim)};
   for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) {
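Alongside the performance work, the PR adds `MS_EXCEPTION_IF_NULL` guards on output-pointer parameters before they are written through (here and in `CreateOutputsOfReduceGrad` above). The general shape of such a guard, with a stand-in assertion macro rather than MindSpore's real one:

```cpp
#include <stdexcept>
#include <vector>

#define CHECK_NOT_NULL(ptr) \
  if ((ptr) == nullptr) throw std::runtime_error("null pointer: " #ptr)

void CollectOutputs(std::vector<int> *outputs) {
  CHECK_NOT_NULL(outputs);  // fail fast, before the first dereference
  outputs->push_back(1);
}

int main() {
  std::vector<int> outs;
  CollectOutputs(&outs);
  return outs.size() == 1 ? 0 : 1;
}
```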
@@ -83,11 +83,11 @@ const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfN
   size_t cur_input_index = 1;
   // Divide the inputs of pack by inputs_divisor_.
   while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
-    base_concat_inputs.push_back(CreateNewPack(func_graph, cnode, cur_input_index, inputs_divisor_));
+    base_concat_inputs.emplace_back(CreateNewPack(func_graph, cnode, cur_input_index, inputs_divisor_));
     cur_input_index += inputs_divisor_;
   }
   if (cur_input_index <= origin_input_size) {
-    base_concat_inputs.push_back(
+    base_concat_inputs.emplace_back(
         CreateNewPack(func_graph, cnode, cur_input_index, origin_input_size - cur_input_index + 1));
   }

@@ -96,17 +96,16 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int
 void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
                                      const std::vector<int> &size_splits_base, int split_dim, int num_split) {
   SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
-  std::vector<TypeId> base_type_ids;
-  std::vector<std::vector<size_t>> base_output_shapes_base;
   auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
   TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
+  std::vector<TypeId> base_type_ids(num_split, type_id);
+  std::vector<std::vector<size_t>> base_output_shapes_base;
   if (split_dim < 0) {
     split_dim += output_shape.size();
   }
   for (int i = 0; i < num_split; ++i) {
     output_shape[split_dim] = size_splits_base[i];
     base_output_shapes_base.emplace_back(output_shape);
-    base_type_ids.emplace_back(type_id);
   }
   AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
 }
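Since every element of `base_type_ids` is the same `type_id`, the fill constructor `std::vector<TypeId>(num_split, type_id)` replaces the per-iteration `emplace_back`, sizing the vector once up front instead of growing it incrementally. The same idiom appears twice more in `DoFission` below. Compact form:

```cpp
#include <vector>

int main() {
  int num_split = 4;
  int type_id = 7;  // stand-in for TypeId
  // One allocation, four copies of type_id; no incremental growth.
  std::vector<int> base_type_ids(num_split, type_id);
  return (base_type_ids.size() == 4 && base_type_ids[3] == 7) ? 0 : 1;
}
```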
| @@ -118,17 +117,14 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int | |||||
| // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs. | // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs. | ||||
| auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split)); | auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split)); | ||||
| std::vector<int> size_splits_new; | |||||
| for (int i = 0; i < divisor; ++i) { | |||||
| size_splits_new.emplace_back(small_split_size); | |||||
| } | |||||
| std::vector<int> size_splits_new(divisor, small_split_size); | |||||
| // Create new output shape and new output type id for each new Splitv node which has full inputs. | // Create new output shape and new output type id for each new Splitv node which has full inputs. | ||||
| std::vector<TypeId> new_type_ids; | std::vector<TypeId> new_type_ids; | ||||
| std::vector<std::vector<size_t>> new_output_shapes; | std::vector<std::vector<size_t>> new_output_shapes; | ||||
| CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes); | CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes); | ||||
| // Create make_tuple input to create a make_tuple for replacing the old Split node. | // Create make_tuple input to create a make_tuple for replacing the old Split node. | ||||
| std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; | |||||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||||
| // Start to divide the outputs of Split. | // Start to divide the outputs of Split. | ||||
| std::vector<int> size_splits_base; | std::vector<int> size_splits_base; | ||||
| const auto base_split_size = divisor * small_split_size; | const auto base_split_size = divisor * small_split_size; | ||||
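The size_splits_new change collapses a grow-by-one loop into the vector fill constructor, which allocates once and copies the value `divisor` times. A minimal demonstration that the two forms produce identical contents:

#include <cassert>
#include <vector>

int main() {
  const int divisor = 4, small_split_size = 16;
  // Fill constructor: a single allocation holding `divisor` copies of the value...
  std::vector<int> a(divisor, small_split_size);
  // ...equivalent to the loop it replaces, minus the incremental growth.
  std::vector<int> b;
  for (int i = 0; i < divisor; ++i) b.emplace_back(small_split_size);
  assert(a == b);
}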
| @@ -147,10 +143,7 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int | |||||
| auto last_node_num_split = num_split - cur_output_index; | auto last_node_num_split = num_split - cur_output_index; | ||||
| if (last_node_num_split > 1) { | if (last_node_num_split > 1) { | ||||
| CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); | CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); | ||||
| std::vector<int> size_splits_new_last; | |||||
| for (int i = 0; i < last_node_num_split; ++i) { | |||||
| size_splits_new_last.emplace_back(small_split_size); | |||||
| } | |||||
| std::vector<int> size_splits_new_last(last_node_num_split, small_split_size); | |||||
| SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); | SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); | ||||
| // Create new output shape and new output type id for the last Splitv node | // Create new output shape and new output type id for the last Splitv node | ||||
| std::vector<TypeId> last_new_type_ids; | std::vector<TypeId> last_new_type_ids; | ||||
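The same fill-constructor simplification applies to the tail Splitv. To see how the two-level fission fits together, here is a worked sketch with illustrative numbers (not taken from the source): the base Splitv cuts coarse slices of base_split_size = divisor * small_split_size, each coarse slice is re-split into `divisor` small slices, and a tail with a single output skips the extra Splitv entirely:

#include <iostream>
#include <vector>

int main() {
  // Illustrative numbers: 9 equal outputs of 16 rows, regrouped with divisor 4.
  const int num_split = 9, divisor = 4, small_split_size = 16;
  const int base_split_size = divisor * small_split_size;  // 64
  // Base Splitv: one coarse slice per full group, plus a tail slice.
  std::vector<int> size_splits_base;
  int cur = 0;
  while (num_split - cur >= divisor) {
    size_splits_base.emplace_back(base_split_size);
    cur += divisor;
  }
  const int tail = num_split - cur;
  if (tail > 0) size_splits_base.emplace_back(tail * small_split_size);
  for (int s : size_splits_base) std::cout << s << ' ';  // 64 64 16
  // Each 64-row slice is then re-split into four 16-row outputs; the
  // 16-row tail has only one output, so the slice is used directly.
}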
| @@ -44,7 +44,7 @@ void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vect | |||||
| MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should has some outputs"; | MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should has some outputs"; | ||||
| } | } | ||||
| for (const auto &node_index : manager->node_users()[bn]) { | for (const auto &node_index : manager->node_users()[bn]) { | ||||
| AnfNodePtr output = node_index.first; | |||||
| const AnfNodePtr &output = node_index.first; | |||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| bn_outputs->push_back(output); | bn_outputs->push_back(output); | ||||
| } | } | ||||
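Binding node_index.first through a const reference instead of copying avoids an atomic reference-count bump per user node, since AnfNodePtr is a shared_ptr alias (an assumption consistent with the surrounding code). A generic demonstration of the difference:

#include <iostream>
#include <map>
#include <memory>

int main() {
  using NodePtr = std::shared_ptr<int>;  // stand-in for AnfNodePtr
  std::map<NodePtr, int> users{{std::make_shared<int>(7), 0}};
  for (const auto &node_index : users) {
    // A const reference observes the pointer without touching the count...
    const NodePtr &ref = node_index.first;
    std::cout << ref.use_count() << '\n';  // 1
    // ...while a copy performs an atomic increment on every iteration.
    NodePtr copy = node_index.first;
    std::cout << ref.use_count() << '\n';  // 2
  }
}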
| @@ -313,9 +313,9 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(outputs); | MS_EXCEPTION_IF_NULL(outputs); | ||||
| for (size_t i = 0; i < output_num; i++) { | for (size_t i = 0; i < output_num; i++) { | ||||
| auto idx = NewValueNode(SizeToInt(i)); | |||||
| MS_EXCEPTION_IF_NULL(idx); | |||||
| int temp = SizeToInt(i); | int temp = SizeToInt(i); | ||||
| auto idx = NewValueNode(temp); | |||||
| MS_EXCEPTION_IF_NULL(idx); | |||||
| auto imm = std::make_shared<Int32Imm>(temp); | auto imm = std::make_shared<Int32Imm>(temp); | ||||
| auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm); | auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm); | ||||
| idx->set_abstract(abstract_scalar); | idx->set_abstract(abstract_scalar); | ||||
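The reorder converts the index once via SizeToInt and reuses the resulting int for both the value node and the Int32Imm, rather than creating the node first and converting afterwards. If SizeToInt follows the usual checked-narrowing pattern, it behaves roughly like this sketch (an assumption; the real helper lives in MindSpore's utility headers, and the name below is hypothetical):

#include <cstddef>
#include <iostream>
#include <limits>
#include <stdexcept>

// Hypothetical checked narrowing: size_t -> int, rejecting values that overflow.
int SizeToIntChecked(size_t v) {
  if (v > static_cast<size_t>(std::numeric_limits<int>::max())) {
    throw std::out_of_range("size_t value exceeds INT_MAX");
  }
  return static_cast<int>(v);
}

int main() { std::cout << SizeToIntChecked(42) << '\n'; }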
| @@ -745,7 +745,7 @@ void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> & | |||||
| for (size_t i = 0; i < types.size(); ++i) { | for (size_t i = 0; i < types.size(); ++i) { | ||||
| std::vector<int> shape_int; | std::vector<int> shape_int; | ||||
| std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt); | std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt); | ||||
| abstract_list.push_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int)); | |||||
| abstract_list.emplace_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int)); | |||||
| } | } | ||||
| auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list); | auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list); | ||||
| node->set_abstract(abstract_tuple); | node->set_abstract(abstract_tuple); | ||||
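For a vector of shared_ptr filled from make_shared results, emplace_back and push_back compile to the same move construction, so this change reads as consistency with neighboring code rather than an optimization. Where emplace_back genuinely pays off is when the element can be constructed in place from raw arguments:

#include <string>
#include <vector>

int main() {
  std::vector<std::string> v;
  v.push_back(std::string(40, 'x'));  // builds a temporary string, then moves it
  v.emplace_back(40, 'x');            // constructs the string inside the vector
}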
| @@ -550,6 +550,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern | |||||
| kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type); | kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type); | ||||
| auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); | ||||
| // If AiCore cannot find a valid kernel info, reload the AiCpu kernel info list to find one | // If AiCore cannot find a valid kernel info, reload the AiCpu kernel info list to find one | ||||
| if (select_status == kNoMatched) { | if (select_status == kNoMatched) { | ||||
| MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() | MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() | ||||
| << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <unistd.h> | #include <unistd.h> | ||||
| #include <fcntl.h> | #include <fcntl.h> | ||||
| #include <sys/stat.h> | #include <sys/stat.h> | ||||
| #include <sys/time.h> | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| @@ -337,5 +338,33 @@ static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { | |||||
| MS_LOG(DEBUG) << "File `" << file_name << "` change mode failed! May be not exist."; | MS_LOG(DEBUG) << "File `" << file_name << "` change mode failed! May be not exist."; | ||||
| } | } | ||||
| } | } | ||||
| static inline uint64_t GetCurrentUSec() { | |||||
| struct timeval tv; | |||||
| int ret = gettimeofday(&tv, nullptr); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "Fail gettimeofday, ret = " << ret; | |||||
| } | |||||
| return static_cast<uint64_t>(tv.tv_usec + tv.tv_sec * 1000000); | |||||
| } | |||||
| #define PROF_START(stage) uint64_t start_usec_##stage = mindspore::GetCurrentUSec() | |||||
| #define PROF_END(stage) \ | |||||
| do { \ | |||||
| uint64_t end_usec_##stage = mindspore::GetCurrentUSec(); \ | |||||
| MS_LOG(INFO) << #stage << " costs " << (end_usec_##stage - start_usec_##stage) << " usec."; \ | |||||
| } while (0) | |||||
| #define PROF_MULTI_DEFINE(stage) \ | |||||
| static uint64_t total_##stage = 0; \ | |||||
| static uint64_t count_##stage = 0; | |||||
| #define PROF_MULTI_START(stage) uint64_t start_usec_##stage = mindspore::GetCurrentUSec() | |||||
| #define PROF_MULTI_END(stage) \ | |||||
| ++count_##stage; \ | |||||
| uint64_t end_usec_##stage = mindspore::GetCurrentUSec(); \ | |||||
| total_##stage += (end_usec_##stage - start_usec_##stage) | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_UTILS_UTILS_H_ | #endif // MINDSPORE_CCSRC_UTILS_UTILS_H_ | ||||
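The new helpers add microsecond wall-clock profiling for compile stages: PROF_START/PROF_END time a single region and log it, while the PROF_MULTI_* trio accumulates a total and a call count across repeated invocations. A usage sketch under the assumption that the header above is included as "utils/utils.h" (the timed functions are illustrative):

#include "utils/utils.h"  // assumed include path for the macros defined above

void CompileGraphOnce() {
  PROF_START(graph_compile);
  // ... compile work ...
  PROF_END(graph_compile);  // logs: "graph_compile costs N usec."
}

PROF_MULTI_DEFINE(kernel_select)  // static total_/count_kernel_select counters

void SelectOneKernel() {
  PROF_MULTI_START(kernel_select);
  // ... per-node selection work ...
  PROF_MULTI_END(kernel_select);  // bumps count_ and adds elapsed usec to total_
}

One caveat worth noting: unlike PROF_END, PROF_MULTI_END expands to multiple statements without a do/while wrapper, so it must not be used as the unbraced body of an if or a loop.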