| @@ -96,11 +96,13 @@ class BiasAddGpuKernel : public GpuKernel { | |||||
| b_dims[i] = (i == pos) ? SizeToInt(x_shape[i]) : 1; | b_dims[i] = (i == pos) ? SizeToInt(x_shape[i]) : 1; | ||||
| } | } | ||||
| auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0); | |||||
| auto cudnn_cal_format = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW; | |||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetTensorNdDescriptorEx(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()), | |||||
| cudnnSetTensorNdDescriptorEx(x_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()), | |||||
| "cudnnSetTensorNdDescriptor failed"); | "cudnnSetTensorNdDescriptor failed"); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetTensorNdDescriptorEx(b_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()), | |||||
| cudnnSetTensorNdDescriptorEx(b_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()), | |||||
| "cudnnSetTensorNdDescriptor failed"); | "cudnnSetTensorNdDescriptor failed"); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN), | cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN), | ||||
| @@ -94,11 +94,13 @@ class BiasAddGradGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| } | } | ||||
| auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0); | |||||
| auto cudnn_cal_format = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW; | |||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetTensorNdDescriptorEx(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()), | |||||
| cudnnSetTensorNdDescriptorEx(dy_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()), | |||||
| "cudnnSetTensorNdDescriptor failed"); | "cudnnSetTensorNdDescriptor failed"); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), | |||||
| cudnnSetTensorNdDescriptorEx(db_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), | |||||
| "cudnnSetTensorNdDescriptor failed"); | "cudnnSetTensorNdDescriptor failed"); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, | cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <algorithm> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| @@ -38,21 +39,22 @@ std::vector<int> TransposeAxis(const std::string &src_format, const std::string | |||||
| } | } | ||||
| // Transpose can be replaced by nop reshape in some situations. | // Transpose can be replaced by nop reshape in some situations. | ||||
| // 1. out_shape [x, 1, 1, y] with transpose perm {0, 2, 3, 1} | |||||
| // 2. out_shape [x, y, 1, 1] with transpose perm {0, 3, 1, 2} | |||||
| // 1. out_shape [x, 1, 1, y] | |||||
| // 2. out_shape [x, y, 1, 1] | |||||
| // 3. out_shape [x, 1, y, 1] | |||||
| bool IsFakeTranspose(const std::vector<size_t> &out_shape, const std::vector<int> &transpose_perm) { | bool IsFakeTranspose(const std::vector<size_t> &out_shape, const std::vector<int> &transpose_perm) { | ||||
| if (out_shape.size() != 4) { | if (out_shape.size() != 4) { | ||||
| MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D."; | MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D."; | ||||
| } | } | ||||
| std::vector<int> perm1 = {0, 2, 3, 1}; | std::vector<int> perm1 = {0, 2, 3, 1}; | ||||
| std::vector<int> perm2 = {0, 3, 1, 2}; | std::vector<int> perm2 = {0, 3, 1, 2}; | ||||
| if (transpose_perm == perm1) { | |||||
| return (out_shape[1] == 1 && out_shape[2] == 1); | |||||
| } else if (transpose_perm == perm2) { | |||||
| return (out_shape[2] == 1 && out_shape[3] == 1); | |||||
| } else { | |||||
| return false; | |||||
| auto num = std::count(out_shape.begin(), out_shape.end(), 1); | |||||
| if ((transpose_perm == perm1) || (transpose_perm == perm2)) { | |||||
| if (num >= 2) { | |||||
| return true; | |||||
| } | |||||
| } | } | ||||
| return false; | |||||
| } | } | ||||
| void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format, | void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format, | ||||
| @@ -73,6 +75,8 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string | |||||
| // Insert transpose op between node and used_node whose position is used_node_index. | // Insert transpose op between node and used_node whose position is used_node_index. | ||||
| CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node, | CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node, | ||||
| int used_node_index, const std::vector<int> &transpose_perm) { | int used_node_index, const std::vector<int> &transpose_perm) { | ||||
| MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope() | |||||
| << ", index: " << used_node_index; | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| // 0.Judge whether it is a fake transpose | // 0.Judge whether it is a fake transpose | ||||
| auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index); | auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index); | ||||
| @@ -95,15 +99,10 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co | |||||
| if (!is_fake) { | if (!is_fake) { | ||||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op); | AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op); | ||||
| } | } | ||||
| // 4.Set the input of used_node. | |||||
| MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope() | |||||
| << ", index: " << used_node_index; | |||||
| AnfAlgo::SetNodeInput(utils::cast<CNodePtr>(used_node), transpose_op, used_node_index); | |||||
| // 5. Update the manager info of transpose op. | |||||
| // 4. Set the new edge of transpose op. | |||||
| FuncGraphManagerPtr manager = graph->manager(); | FuncGraphManagerPtr manager = graph->manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| manager->Clear(); | |||||
| manager->AddFuncGraph(graph); | |||||
| manager->SetEdge(used_node, used_node_index + 1, transpose_op); | |||||
| return transpose_op; | return transpose_op; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -252,11 +252,11 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s | |||||
| bn_cnt++; | bn_cnt++; | ||||
| } | } | ||||
| } | } | ||||
| if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) { | |||||
| format_transform_ = false; | |||||
| if (conv_cnt + bn_cnt > 1) { | |||||
| format_transform_ = true; | |||||
| return; | return; | ||||
| } | } | ||||
| format_transform_ = true; | |||||
| format_transform_ = false; | |||||
| } | } | ||||
| void SetKernelInfo(const CNodePtr &kernel_node) { | void SetKernelInfo(const CNodePtr &kernel_node) { | ||||
| @@ -34,23 +34,27 @@ namespace gpu { | |||||
| // map<opName, (inputFormatPosition, outputFormatPosition)>, used to get the insert positions of the format transform. | // map<opName, (inputFormatPosition, outputFormatPosition)>, used to get the insert positions of the format transform. | ||||
| // If the input position list is empty, insert at all input positions, because the number of inputs of this op is variable. | // If the input position list is empty, insert at all input positions, because the number of inputs of this op is variable. | ||||
| static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = { | static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = { | ||||
| // Format sensitive. | |||||
| {prim::kPrimConv2D->name(), {{0, 1}, {0}}}, | {prim::kPrimConv2D->name(), {{0, 1}, {0}}}, | ||||
| {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}}, | {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}}, | ||||
| {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}}, | {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}}, | ||||
| {prim::kPrimRelu->name(), {{0}, {0}}}, | |||||
| {prim::kPrimReluGrad->name(), {{0, 1}, {0}}}, | |||||
| {prim::kPrimMaxPool->name(), {{0}, {0}}}, | {prim::kPrimMaxPool->name(), {{0}, {0}}}, | ||||
| {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}}, | {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}}, | ||||
| {kSliceOpName, {{0}, {0}}}, | |||||
| {kAvgPoolOpName, {{0}, {0}}}, | {kAvgPoolOpName, {{0}, {0}}}, | ||||
| {kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}}, | {kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}}, | ||||
| {kTensorAddOpName, {{0, 1}, {0}}}, | |||||
| {kFusedBatchNormEx, {{0}, {0}}}, | {kFusedBatchNormEx, {{0}, {0}}}, | ||||
| {kFusedBatchNormExWithActivation, {{0}, {0}}}, | {kFusedBatchNormExWithActivation, {{0}, {0}}}, | ||||
| {kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}}, | {kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}}, | ||||
| {kFusedBatchNormGradEx, {{0, 1}, {0}}}, | {kFusedBatchNormGradEx, {{0, 1}, {0}}}, | ||||
| {kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}}, | {kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}}, | ||||
| {kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}}, | {kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}}, | ||||
| {kBiasAddOpName, {{0}, {0}}}, | |||||
| {prim::kPrimBiasAddGrad->name(), {{0}, {}}}, | |||||
| // Format insensitive. | |||||
| {prim::kPrimRelu->name(), {{0}, {0}}}, | |||||
| {prim::kPrimReluGrad->name(), {{0, 1}, {0}}}, | |||||
| {kSliceOpName, {{0}, {0}}}, | |||||
| {kTensorAddOpName, {{0, 1}, {0}}}, | |||||
| {prim::kPrimConcat->name(), {{}, {0}}}, | {prim::kPrimConcat->name(), {{}, {0}}}, | ||||
| {prim::kPrimAddN->name(), {{}, {0}}}, | {prim::kPrimAddN->name(), {{}, {0}}}, | ||||
| }; | }; | ||||
| @@ -74,8 +78,6 @@ class FormatTransformChecker { | |||||
| FormatTransformChecker &operator=(const FormatTransformChecker &); | FormatTransformChecker &operator=(const FormatTransformChecker &); | ||||
| bool format_transform_{true}; | bool format_transform_{true}; | ||||
| static constexpr size_t kConv2dCount = 96; | |||||
| static constexpr size_t kFusedBatchNormCount = 94; | |||||
| }; | }; | ||||
| class KernelAttr { | class KernelAttr { | ||||