From: @lianliguang
Reviewed-by: @chujinjin
Signed-off-by: @chujinjin
tags/v1.1.0
@@ -195,6 +195,7 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
   auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm");
   data_layout_pm->AddPass(std::make_shared<ChangeAxisOfReduceKernel>());
   data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
+  data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
   data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>());
   data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
   data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());

@@ -338,7 +339,9 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
+  ir_fusion_pm->AddPass(std::make_shared<DynamicGRUV2GradFission>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
+  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
   ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
   optimizer->AddPassManager(ir_fusion_pm);
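
Both hunks above extend ordered pass pipelines: AddPass appends, and the passes later run in insertion order, which is why each new pass is placed at a specific point in the sequence. A minimal sketch of that pattern, with Graph and Pass as hypothetical stand-ins rather than MindSpore's real classes:

// Minimal sketch of the pass-manager pattern used above; Graph and Pass
// are hypothetical stand-ins, not MindSpore's actual API.
#include <memory>
#include <string>
#include <vector>

struct Graph {};  // placeholder for the kernel graph being optimized

class Pass {
 public:
  virtual ~Pass() = default;
  virtual bool Run(Graph *graph) = 0;  // returns true if the graph changed
};

class PassManager {
 public:
  explicit PassManager(std::string name) : name_(std::move(name)) {}
  void AddPass(std::shared_ptr<Pass> pass) { passes_.push_back(std::move(pass)); }
  // Passes run strictly in the order they were added, so inserting
  // DynamicRNNGradReformat before RunOpInsertTransData guarantees it sees
  // the graph before any trans-data nodes are laid out.
  void Run(Graph *graph) const {
    for (const auto &pass : passes_) {
      pass->Run(graph);
    }
  }

 private:
  std::string name_;
  std::vector<std::shared_ptr<Pass>> passes_;
};
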
@@ -526,13 +526,6 @@ bool TransDataType(const TypeIdArgs &args, void *result) {
 }
 bool TransFormat(const FormatArgs &args, void *result) {
-  using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
-  const std::map<std::string, FormatTransfer> format_trans_map{
-    {kOpFormat_FRAC_Z, NchwToFracZ},           {kOpFormat_FRAC_NZ, NchwToFracNz},
-    {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
-    {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
-    {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
   MS_LOG(DEBUG) << "Start trans format.";
   if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";

@@ -541,15 +534,14 @@ bool TransFormat(const FormatArgs &args, void *result) {
   if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
     return NchwTo4D(args, result);
   }
-  auto iter = format_trans_map.find(args.device_format);
-  if (iter == format_trans_map.end()) {
+  auto iter = kTransFormatMapOfHostToDevice.find(args.device_format);
+  if (iter == kTransFormatMapOfHostToDevice.end()) {
     MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
   }
   return iter->second(args, result);
 }
 bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
-  using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
   const std::map<std::string, FormatTransfer> format_trans_map{
     {kOpFormat_FRAC_Z, FracZToNchw},     {kOpFormat_FRAC_NZ, FracNzToNchw},
     {kOpFormat_NC1HWC0, Nc1hwc0ToNchw},  {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
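
With the function-local map gone, TransFormat dispatches host-to-device conversions through the shared kTransFormatMapOfHostToDevice declared in the header (next hunk). A self-contained sketch of this dispatch-by-map pattern; the format names and converter bodies below are placeholders, not the real trans implementations:

// Dispatch-by-map sketch mirroring TransFormat; FormatArgs and the
// converters are placeholders, not the real trans code.
#include <functional>
#include <iostream>
#include <map>
#include <string>

struct FormatArgs {};  // stand-in for the real trans::FormatArgs

using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;

bool NchwToFracZ(const FormatArgs &, void *) { return true; }   // placeholder body
bool NchwToFracNz(const FormatArgs &, void *) { return true; }  // placeholder body

const std::map<std::string, FormatTransfer> kHostToDevice{
    {"FRACTAL_Z", NchwToFracZ},
    {"FRACTAL_NZ", NchwToFracNz},
};

bool TransFormat(const std::string &device_format, const FormatArgs &args, void *result) {
  auto iter = kHostToDevice.find(device_format);
  if (iter == kHostToDevice.end()) {
    std::cerr << "Unexpected format[" << device_format << "]\n";
    return false;
  }
  return iter->second(args, result);  // invoke the registered converter
}
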
@@ -76,6 +76,13 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
 bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
 bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
 bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
+using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
+const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
+  {kOpFormat_FRAC_Z, NchwToFracZ},           {kOpFormat_FRAC_NZ, NchwToFracNz},
+  {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
+  {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
+  {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
 }  // namespace trans
 }  // namespace mindspore
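
Hoisting the map into the trans header also makes it usable as a registry of which formats the host can convert, which is how the kernel-selection change below consumes it. A hypothetical helper (not part of the diff) naming that membership test:

// Assumes #include "common/trans.h"; hypothetical helper, not in the diff.
bool HostCanConvert(const std::string &format) {
  return trans::kTransFormatMapOfHostToDevice.find(format) !=
         trans::kTransFormatMapOfHostToDevice.end();
}
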
@@ -29,6 +29,7 @@
 #include "backend/kernel_compiler/oplib/oplib.h"
 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
 #include "backend/session/anf_runtime_algorithm.h"
+#include "common/trans.h"
 #include "debug/anf_ir_dump.h"
 #include "frontend/operator/ops.h"
 #include "utils/ms_context.h"

@@ -382,14 +383,15 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
       continue;
     }
     auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-    std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
+    auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
+    std::vector<std::string> output_format = {refresh_format};
+    // If the format is not in the host convert map, the host has not registered a convert function for it.
+    if (trans::kTransFormatMapOfHostToDevice.find(refresh_format) == trans::kTransFormatMapOfHostToDevice.end() &&
+        refresh_format != kOpFormat_DEFAULT) {
+      output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
+    }
     if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
         AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
-      if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM ||
-          selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D ||
-          selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
-        output_format = {selected_kernel_info->GetInputFormat(input_index)};
-      }
       builder->SetOutputsFormat(output_format);
       std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
       builder->SetOutputsDeviceType(output_type);

@@ -397,11 +399,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
       continue;
     }
     if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
-      if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM ||
-          selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D ||
-          selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
-        output_format = {selected_kernel_info->GetInputFormat(input_index)};
-      }
       builder->SetOutputsFormat(output_format);
       std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
       builder->SetOutputsDeviceType(output_type);
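
The deleted special-casing was also broken: a chain like format != A || format != B || format != C is true for every format, since no value can equal all three constants at once, so output_format was always overwritten unconditionally. The replacement rule at the top of SetTensorDeviceInfo keeps the selected kernel's input format only when the host can actually materialize it, and otherwise falls back to the format the producer already emits. A sketch of that decision, with hypothetical parameter names standing in for the AnfAlgo / KernelBuildInfo calls:

// Sketch of the replacement format-selection rule; parameter names are
// hypothetical stand-ins, not the real AnfAlgo / KernelBuildInfo API.
#include <set>
#include <string>

const std::string kOpFormat_DEFAULT = "DefaultFormat";  // assumed value of the constant

std::string ChooseOutputFormat(const std::string &refresh_format,          // selected kernel's input format
                               const std::string &producer_output_format,  // real input node's output format
                               const std::set<std::string> &host_convertible_formats) {
  // Keep the selected format when the host can convert to it (or it is the
  // default format)...
  if (refresh_format == kOpFormat_DEFAULT ||
      host_convertible_formats.count(refresh_format) != 0) {
    return refresh_format;
  }
  // ...otherwise fall back to the format the producer already emits, since
  // no host convert function is registered for the selected format.
  return producer_output_format;
}
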