Merge pull request !1562 from lianliguang/r0.3 (tags/v0.3.0-alpha)
| @@ -166,20 +166,10 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co | |||
| std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| // we set special device info of a input tensor. | |||
| bool is_ref = false; | |||
| auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); | |||
| if (op_info != nullptr) { | |||
| is_ref = op_info->is_ref(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | |||
| if (MsContext::GetInstance()->execution_mode() == kPynativeMode && | |||
| AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { | |||
| if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) { | |||
| std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; | |||
| builder->SetOutputsFormat(output_format); | |||
| std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; | |||
| std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; | |||
| builder->SetOutputsDeviceType(output_type); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); | |||
| } | |||
| @@ -383,6 +383,11 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp | |||
| return false; | |||
| } | |||
| std::vector<Axis> reshape_type; | |||
| if (!StringToAxisVector(input->reshape_type(), &reshape_type)) { | |||
| return false; | |||
| } | |||
| if (param_type == "dynamic") { | |||
| if (dyn_input_sizes.empty()) { | |||
| MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"; | |||
| @@ -394,6 +399,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp | |||
| auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); | |||
| inputs_device_type.push_back(type_id); | |||
| inputs_format.push_back(formats[builder_idex]); | |||
| reshape_types.push_back(reshape_type); | |||
| } | |||
| dyn_input_idx++; | |||
| } else if (param_type == "required") { | |||
| @@ -401,6 +407,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp | |||
| auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); | |||
| inputs_device_type.push_back(type_id); | |||
| inputs_format.push_back(formats[builder_idex]); | |||
| reshape_types.push_back(reshape_type); | |||
| } else { | |||
| if (kernel_info_index < real_input_num) { | |||
| MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index; | |||
| @@ -408,13 +415,9 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp | |||
| auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); | |||
| inputs_device_type.push_back(type_id); | |||
| inputs_format.push_back(formats[builder_idex]); | |||
| reshape_types.push_back(reshape_type); | |||
| } | |||
| } | |||
| std::vector<Axis> reshape_type; | |||
| if (!StringToAxisVector(input->reshape_type(), &reshape_type)) { | |||
| return false; | |||
| } | |||
| reshape_types.push_back(reshape_type); | |||
| } | |||
| builder->SetInputReshapeType(reshape_types); | |||
| @@ -442,6 +445,11 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou | |||
| MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!"; | |||
| continue; | |||
| } | |||
| std::vector<Axis> reshape_type; | |||
| if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { | |||
| return false; | |||
| } | |||
| size_t output_num = 0; | |||
| if (output->param_type() == "dynamic") { | |||
| if (outputs.size() > 1) { | |||
| @@ -467,12 +475,9 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou | |||
| auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); | |||
| outputs_device_type.push_back(type_id); | |||
| outputs_format.push_back(formats[builder_idex]); | |||
| reshape_types.push_back(reshape_type); | |||
| output_idx++; | |||
| } | |||
| std::vector<Axis> reshape_type; | |||
| if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { | |||
| return false; | |||
| } | |||
| reshape_types.push_back(reshape_type); | |||
| } | |||
| @@ -33,12 +33,15 @@ using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | |||
| namespace { | |||
| kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, | |||
| const AnfNodePtr &node, const TypeId device_type, | |||
| const kernel::KernelBuildInfo &ori_build_info) { | |||
| const kernel::KernelBuildInfo &ori_build_info, | |||
| const std::vector<kernel::Axis> &reshape_type) { | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat({input_format}); | |||
| builder.SetOutputsFormat({output_format}); | |||
| builder.SetInputsDeviceType({device_type}); | |||
| builder.SetOutputsDeviceType({device_type}); | |||
| builder.SetOutputReshapeType({reshape_type}); | |||
| builder.SetInputReshapeType({reshape_type}); | |||
| builder.SetKernelType(ori_build_info.kernel_type()); | |||
| builder.SetFusionType(ori_build_info.fusion_type()); | |||
| builder.SetProcessor(ori_build_info.processor()); | |||
| @@ -175,6 +178,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| AnfNodePtr trans_node = nullptr; | |||
| AnfNodePtr input_node = node; | |||
| AnfNodePtr trans_data = nullptr; | |||
| std::vector<kernel::Axis> reshape_type = AnfAlgo::GetOutputReshapeType(node, 0); | |||
| TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (origin_format.empty() || dest_format.empty()) { | |||
| @@ -189,6 +193,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| input_node = AnfAlgo::GetInputNode(cnode, insert_index); | |||
| reshape_type = AnfAlgo::GetInputReshapeType(node, insert_index); | |||
| } | |||
| bool need_padding = false; | |||
| if (is_insert_input) { | |||
| @@ -222,7 +227,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| MS_EXCEPTION_IF_NULL(trans_data); | |||
| MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); | |||
| auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); | |||
| auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info); | |||
| auto kernel_build_info = | |||
| RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info, reshape_type); | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get()); | |||
| return trans_node; | |||
| } | |||
| @@ -309,9 +315,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod | |||
| auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); | |||
| auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); | |||
| auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { | |||
| if (node->isa<ValueNode>()) { | |||
| return true; | |||
| } else if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) { | |||
| if (node->isa<ValueNode>() || node->isa<Parameter>()) { | |||
| return true; | |||
| } | |||
| return false; | |||