|
|
|
@@ -27,6 +27,7 @@ |
|
|
|
#include "utils/ms_utils.h" |
|
|
|
#include "backend/kernel_compiler/oplib/oplib.h" |
|
|
|
#include "backend/kernel_compiler/oplib/opinfo.h" |
|
|
|
#include "runtime/device/gpu/cuda_common.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace device { |
|
|
|
@@ -236,6 +237,12 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) { |
|
|
|
// TensorCore can be used only in Volta or newer devices. |
|
|
|
const int marjor_sm = GET_MAJOR_SM; |
|
|
|
if (marjor_sm < RECOMMEND_SM) { |
|
|
|
format_transform_ = false; |
|
|
|
return; |
|
|
|
} |
|
|
|
auto kernels = kernel_graph->execution_order(); |
|
|
|
size_t conv_cnt = 0; |
|
|
|
size_t bn_cnt = 0; |
|
|
|
|