|
|
|
@@ -22,6 +22,7 @@ |
|
|
|
#include <string>
|
|
|
|
#include <vector>
|
|
|
|
#include "kernel/kernel.h"
|
|
|
|
#include "kernel/gpu/kernel_constants.h"
|
|
|
|
#include "device/gpu/gpu_device_manager.h"
|
|
|
|
#include "device/gpu/gpu_common.h"
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
@@ -79,6 +80,22 @@ class GpuKernel : public KernelMod { |
|
|
|
"must match the corresponding dimension of outC or must be equal to 1.";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// choose the suitable datatype for cudnn/cublas
|
|
|
|
inline cudnnDataType_t GetCudnnDataType(const std::string &Type) {
|
|
|
|
auto type = kCudnnDtypeMap.find(Type);
|
|
|
|
if (type == kCudnnDtypeMap.end()) {
|
|
|
|
MS_EXCEPTION(TypeError) << Type << " is not supported.";
|
|
|
|
}
|
|
|
|
return type->second;
|
|
|
|
}
|
|
|
|
inline cudaDataType_t GetCudaDataType(const std::string &Type) {
|
|
|
|
auto type = kCudaDtypeMap.find(Type);
|
|
|
|
if (type == kCudaDtypeMap.end()) {
|
|
|
|
MS_EXCEPTION(TypeError) << Type << " is not supported.";
|
|
|
|
}
|
|
|
|
return type->second;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace kernel
|
|
|
|
} // namespace mindspore
|
|
|
|
|