|
|
|
@@ -189,7 +189,7 @@ size_t CubeSizeByType(const TypeId data_type) { |
|
|
|
const size_t default_error = 0; |
|
|
|
auto dt_size = abstract::TypeIdSize(data_type); |
|
|
|
if (dt_size < 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Illegal dtype."; |
|
|
|
MS_LOG(ERROR) << "Illegal dtype."; |
|
|
|
return default_error; |
|
|
|
} else if (dt_size == 1) { |
|
|
|
return kCubeSize * 2; |
|
|
|
@@ -206,14 +206,14 @@ bool CheckDims(const std::vector<size_t> &shape) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (!CheckDims(shape)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
} |
|
|
|
return shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (!CheckDims(shape)) { |
|
|
|
MS_LOG(EXCEPTION) << "Ccheck dims failed."; |
|
|
|
} |
|
|
|
@@ -225,7 +225,7 @@ std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape, const Type |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (!CheckDims(shape)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
} |
|
|
|
@@ -237,29 +237,27 @@ std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape, const Type |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (!CheckDims(shape)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
} |
|
|
|
auto kCube = CubeSizeByType(type); |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
auto c1 = DivCeil(shape[kC], kCube); |
|
|
|
auto n0 = DivCeil(shape[kN], kCubeSize); |
|
|
|
device_shape.push_back(shape[kH] * shape[kW] * c1); |
|
|
|
device_shape.push_back(n0); |
|
|
|
const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize; |
|
|
|
const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize; |
|
|
|
device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize); |
|
|
|
device_shape.push_back(cout16 / kCubeSize); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCube); |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (!CheckDims(shape)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
} |
|
|
|
auto kCube = CubeSizeByType(type); |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
const size_t C1 = (shape[kC] + kCube - 1) / kCube; |
|
|
|
const size_t C0 = kCube; |
|
|
|
const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize; |
|
|
|
const size_t C0 = kCubeSize; |
|
|
|
device_shape.push_back(shape[kN]); |
|
|
|
device_shape.push_back(C1); |
|
|
|
device_shape.push_back(shape[kH]); |
|
|
|
@@ -268,7 +266,7 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape, const T |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) { |
|
|
|
// NCDHW |
|
|
|
if (shape.size() != 5) { |
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); |
|
|
|
@@ -285,54 +283,51 @@ std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape, const |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
// NCDHW -> Frac_Z_3D |
|
|
|
if (shape.size() != 5) { |
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); |
|
|
|
} |
|
|
|
auto kCube = CubeSizeByType(type); |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
const size_t C1 = (shape[1] + kCube - 1) / kCube; |
|
|
|
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; |
|
|
|
const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize; |
|
|
|
device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]); |
|
|
|
device_shape.push_back(N1); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCube); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (!CheckDims(shape)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
} |
|
|
|
auto kCube = CubeSizeByType(type); |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
device_shape.push_back((shape[kC] - 1) / kCube + 1); |
|
|
|
device_shape.push_back((shape[kC] - 1) / kCubeSize + 1); |
|
|
|
device_shape.push_back(shape[kH]); |
|
|
|
device_shape.push_back(shape[kW]); |
|
|
|
device_shape.push_back(shape[kN]); |
|
|
|
device_shape.push_back(kCube); |
|
|
|
device_shape.push_back(kCube); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (!CheckDims(shape)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
} |
|
|
|
auto kCube = CubeSizeByType(type); |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
const size_t c0 = 4; |
|
|
|
auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCube); |
|
|
|
auto no = DivCeil(shape.at(kN), kCube); |
|
|
|
auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize); |
|
|
|
auto no = DivCeil(shape.at(kN), kCubeSize); |
|
|
|
device_shape.push_back(first_dim); |
|
|
|
device_shape.push_back(no); |
|
|
|
device_shape.push_back(kCube); |
|
|
|
device_shape.push_back(kCube); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (!CheckDims(shape)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed."; |
|
|
|
} |
|
|
|
@@ -347,7 +342,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape, const |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape, const TypeId &type) { |
|
|
|
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (shape.size() < kNdhwc) { |
|
|
|
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; |
|
|
|
} |
|
|
|
@@ -432,9 +427,8 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std |
|
|
|
return shape_4d; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format, |
|
|
|
const TypeId &type) { |
|
|
|
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &, const TypeId &)>; |
|
|
|
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { |
|
|
|
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; |
|
|
|
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, |
|
|
|
{kOpFormat_NHWC, NhwcDeviceShape}, |
|
|
|
{kOpFormat_HWCN, HwchDeviceShape}, |
|
|
|
@@ -452,9 +446,8 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s |
|
|
|
} |
|
|
|
auto temp_shape = shape; |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
auto kCube = CubeSizeByType(type); |
|
|
|
if (format == kOpFormat_FRAC_NZ) { |
|
|
|
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCube == 0)) { |
|
|
|
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) { |
|
|
|
// For [1] and [1024] shape we can trait it as NZ shape |
|
|
|
return shape; |
|
|
|
} |
|
|
|
@@ -463,12 +456,12 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s |
|
|
|
} else { |
|
|
|
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape)); |
|
|
|
} |
|
|
|
auto w1 = (shape[shape.size() - 1] - 1) / kCube + 1; |
|
|
|
auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1; |
|
|
|
auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1; |
|
|
|
device_shape.push_back(w1); |
|
|
|
device_shape.push_back(h1); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCube); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
return device_shape; |
|
|
|
} else if (format == kOpFormat_FRACTAL_ZN_LSTM) { |
|
|
|
const size_t c0 = 4; |
|
|
|
@@ -490,7 +483,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s |
|
|
|
if (iter == device_shape_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]"; |
|
|
|
} |
|
|
|
return iter->second(temp_shape, type); |
|
|
|
return iter->second(temp_shape); |
|
|
|
} |
|
|
|
|
|
|
|
bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { |
|
|
|
|