| @@ -96,11 +96,13 @@ class BiasAddGpuKernel : public GpuKernel { | |||
| b_dims[i] = (i == pos) ? SizeToInt(x_shape[i]) : 1; | |||
| } | |||
| auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0); | |||
| auto cudnn_cal_format = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensorNdDescriptorEx(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()), | |||
| cudnnSetTensorNdDescriptorEx(x_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensorNdDescriptorEx(b_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()), | |||
| cudnnSetTensorNdDescriptorEx(b_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN), | |||
| @@ -94,11 +94,13 @@ class BiasAddGradGpuKernel : public GpuKernel { | |||
| } | |||
| } | |||
| auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0); | |||
| auto cudnn_cal_format = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensorNdDescriptorEx(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()), | |||
| cudnnSetTensorNdDescriptorEx(dy_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), | |||
| cudnnSetTensorNdDescriptorEx(db_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, | |||
| @@ -18,6 +18,7 @@ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| @@ -38,21 +39,22 @@ std::vector<int> TransposeAxis(const std::string &src_format, const std::string | |||
| } | |||
| // Transpose can be replaceed by nop reshape in some situations. | |||
| // 1. out_shape [x, 1, 1, y] with transpose perm {0, 2, 3, 1} | |||
| // 2. out_shape [x, y, 1, 1] with transpose perm {0, 3, 1, 2} | |||
| // 1. out_shape [x, 1, 1, y] | |||
| // 2. out_shape [x, y, 1, 1] | |||
| // 3. out_shape [x, 1, y, 1] | |||
| bool IsFakeTranspose(const std::vector<size_t> &out_shape, const std::vector<int> &transpose_perm) { | |||
| if (out_shape.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D."; | |||
| } | |||
| std::vector<int> perm1 = {0, 2, 3, 1}; | |||
| std::vector<int> perm2 = {0, 3, 1, 2}; | |||
| if (transpose_perm == perm1) { | |||
| return (out_shape[1] == 1 && out_shape[2] == 1); | |||
| } else if (transpose_perm == perm2) { | |||
| return (out_shape[2] == 1 && out_shape[3] == 1); | |||
| } else { | |||
| return false; | |||
| auto num = std::count(out_shape.begin(), out_shape.end(), 1); | |||
| if ((transpose_perm == perm1) || (transpose_perm == perm2)) { | |||
| if (num >= 2) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format, | |||
| @@ -73,6 +75,8 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string | |||
| // Insert transpose op between node and used_node whose position is used_node_index. | |||
| CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node, | |||
| int used_node_index, const std::vector<int> &transpose_perm) { | |||
| MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope() | |||
| << ", index: " << used_node_index; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| // 0.Judge whether it is a fake transpose | |||
| auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index); | |||
| @@ -95,15 +99,10 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co | |||
| if (!is_fake) { | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op); | |||
| } | |||
| // 4.Set the input of used_node. | |||
| MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope() | |||
| << ", index: " << used_node_index; | |||
| AnfAlgo::SetNodeInput(utils::cast<CNodePtr>(used_node), transpose_op, used_node_index); | |||
| // 5. Update the manager info of transpose op. | |||
| // 4. Set the new edge of transpose op. | |||
| FuncGraphManagerPtr manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->Clear(); | |||
| manager->AddFuncGraph(graph); | |||
| manager->SetEdge(used_node, used_node_index + 1, transpose_op); | |||
| return transpose_op; | |||
| } | |||
| } // namespace | |||
| @@ -252,11 +252,11 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s | |||
| bn_cnt++; | |||
| } | |||
| } | |||
| if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) { | |||
| format_transform_ = false; | |||
| if (conv_cnt + bn_cnt > 1) { | |||
| format_transform_ = true; | |||
| return; | |||
| } | |||
| format_transform_ = true; | |||
| format_transform_ = false; | |||
| } | |||
| void SetKernelInfo(const CNodePtr &kernel_node) { | |||
| @@ -34,23 +34,27 @@ namespace gpu { | |||
| // map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the insert position of format transform. | |||
| // If input position is empty, then insert all the input positions, because the input numbers of this op are variable. | |||
| static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = { | |||
| // Format sensitive. | |||
| {prim::kPrimConv2D->name(), {{0, 1}, {0}}}, | |||
| {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}}, | |||
| {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}}, | |||
| {prim::kPrimRelu->name(), {{0}, {0}}}, | |||
| {prim::kPrimReluGrad->name(), {{0, 1}, {0}}}, | |||
| {prim::kPrimMaxPool->name(), {{0}, {0}}}, | |||
| {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}}, | |||
| {kSliceOpName, {{0}, {0}}}, | |||
| {kAvgPoolOpName, {{0}, {0}}}, | |||
| {kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}}, | |||
| {kTensorAddOpName, {{0, 1}, {0}}}, | |||
| {kFusedBatchNormEx, {{0}, {0}}}, | |||
| {kFusedBatchNormExWithActivation, {{0}, {0}}}, | |||
| {kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}}, | |||
| {kFusedBatchNormGradEx, {{0, 1}, {0}}}, | |||
| {kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}}, | |||
| {kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}}, | |||
| {kBiasAddOpName, {{0}, {0}}}, | |||
| {prim::kPrimBiasAddGrad->name(), {{0}, {}}}, | |||
| // Format insensitive. | |||
| {prim::kPrimRelu->name(), {{0}, {0}}}, | |||
| {prim::kPrimReluGrad->name(), {{0, 1}, {0}}}, | |||
| {kSliceOpName, {{0}, {0}}}, | |||
| {kTensorAddOpName, {{0, 1}, {0}}}, | |||
| {prim::kPrimConcat->name(), {{}, {0}}}, | |||
| {prim::kPrimAddN->name(), {{}, {0}}}, | |||
| }; | |||
| @@ -74,8 +78,6 @@ class FormatTransformChecker { | |||
| FormatTransformChecker &operator=(const FormatTransformChecker &); | |||
| bool format_transform_{true}; | |||
| static constexpr size_t kConv2dCount = 96; | |||
| static constexpr size_t kFusedBatchNormCount = 94; | |||
| }; | |||
| class KernelAttr { | |||