|
|
|
@@ -620,7 +620,27 @@ std::vector<int64_t> FracZDeviceShapeWithGroups(const std::vector<int64_t> &shap |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t> TransShapeToFracNZ(const std::vector<int64_t> &shape) { |
|
|
|
std::vector<size_t> FracNZDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) { |
|
|
|
// For [1] and [1024] shape we can trait it as NZ shape |
|
|
|
return shape; |
|
|
|
} |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
if (shape.size() < 2) { |
|
|
|
MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size(); |
|
|
|
} else { |
|
|
|
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape)); |
|
|
|
} |
|
|
|
auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1; |
|
|
|
auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1; |
|
|
|
device_shape.push_back(w1); |
|
|
|
device_shape.push_back(h1); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t> FracNZDeviceDynamicShape(const std::vector<int64_t> &shape) { |
|
|
|
std::vector<int64_t> device_shape; |
|
|
|
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) { |
|
|
|
// For [1] and [1024] shape we can treat it as NZ shape |
|
|
|
@@ -642,7 +662,21 @@ std::vector<int64_t> TransShapeToFracNZ(const std::vector<int64_t> &shape) { |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t> TransShapeToFracNZLSTM(const std::vector<int64_t> &shape) { |
|
|
|
std::vector<size_t> FracNZLSTMDeviceShape(const std::vector<size_t> &shape) { |
|
|
|
const size_t c0 = 4; |
|
|
|
const size_t h = shape.at(kN) / c0; |
|
|
|
const size_t i = shape.at(kC) - h; |
|
|
|
const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize); |
|
|
|
const size_t second = c0 * DivCeil(h, kCubeSize); |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
device_shape.push_back(first); |
|
|
|
device_shape.push_back(second); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t> FracNZLSTMDeviceDynamicShape(const std::vector<int64_t> &shape) { |
|
|
|
std::vector<int64_t> device_shape; |
|
|
|
const int64_t c0 = 4; |
|
|
|
const int64_t h_shape = shape.at(kN); |
|
|
|
@@ -693,8 +727,8 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) { |
|
|
|
if (shape_size == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ || format == kOpFormat_ChannelLast || |
|
|
|
format == kOpFormat_NCHW) { |
|
|
|
if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW || |
|
|
|
kNoPaddingFormatSet.find(format) != kNoPaddingFormatSet.end()) { |
|
|
|
return false; |
|
|
|
} else if (shape_size < kNchwDims) { |
|
|
|
return true; |
|
|
|
@@ -799,7 +833,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s |
|
|
|
{kOpFormat_NCDHW, NcdhwDeviceShape}, |
|
|
|
{kOpFormat_ChannelLast, ChannelLastDeviceShape}, |
|
|
|
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape}, |
|
|
|
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}}; |
|
|
|
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}, |
|
|
|
{kOpFormat_FRAC_NZ, FracNZDeviceShape}, |
|
|
|
{kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}}; |
|
|
|
|
|
|
|
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { |
|
|
|
return shape; |
|
|
|
@@ -808,37 +844,8 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s |
|
|
|
return FracZDeviceShapeWithGroups(shape, groups); |
|
|
|
} |
|
|
|
auto temp_shape = shape; |
|
|
|
std::vector<size_t> device_shape; |
|
|
|
if (format == kOpFormat_FRAC_NZ) { |
|
|
|
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) { |
|
|
|
// For [1] and [1024] shape we can trait it as NZ shape |
|
|
|
return shape; |
|
|
|
} |
|
|
|
if (shape.size() < 2) { |
|
|
|
MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size(); |
|
|
|
} else { |
|
|
|
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape)); |
|
|
|
} |
|
|
|
auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1; |
|
|
|
auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1; |
|
|
|
device_shape.push_back(w1); |
|
|
|
device_shape.push_back(h1); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
return device_shape; |
|
|
|
} else if (format == kOpFormat_FRACTAL_ZN_LSTM) { |
|
|
|
const size_t c0 = 4; |
|
|
|
const size_t h = shape.at(kN) / c0; |
|
|
|
const size_t i = shape.at(kC) - h; |
|
|
|
const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize); |
|
|
|
const size_t second = c0 * DivCeil(h, kCubeSize); |
|
|
|
device_shape.push_back(first); |
|
|
|
device_shape.push_back(second); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
device_shape.push_back(kCubeSize); |
|
|
|
return device_shape; |
|
|
|
} |
|
|
|
if (format != kOpFormat_ChannelLast && shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) { |
|
|
|
if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM && |
|
|
|
shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) { |
|
|
|
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; |
|
|
|
temp_shape = PaddingShapeTo4dDefault(shape); |
|
|
|
} |
|
|
|
@@ -867,7 +874,9 @@ std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const |
|
|
|
{kOpFormat_NCDHW, NcdhwDeviceDynamicShape}, |
|
|
|
{kOpFormat_ChannelLast, ChannelLastDeviceDynamicShape}, |
|
|
|
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceDynamicShape}, |
|
|
|
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceDynamicShape}}; |
|
|
|
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceDynamicShape}, |
|
|
|
{kOpFormat_FRAC_NZ, FracNZDeviceDynamicShape}, |
|
|
|
{kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceDynamicShape}}; |
|
|
|
|
|
|
|
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) { |
|
|
|
return shape; |
|
|
|
@@ -876,12 +885,8 @@ std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const |
|
|
|
return FracZDeviceShapeWithGroups(shape, groups); |
|
|
|
} |
|
|
|
auto temp_shape = shape; |
|
|
|
if (format == kOpFormat_FRAC_NZ) { |
|
|
|
return TransShapeToFracNZ(shape); |
|
|
|
} else if (format == kOpFormat_FRACTAL_ZN_LSTM) { |
|
|
|
return TransShapeToFracNZLSTM(shape); |
|
|
|
} |
|
|
|
if (format != kOpFormat_ChannelLast && shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) { |
|
|
|
if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM && |
|
|
|
shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) { |
|
|
|
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; |
|
|
|
temp_shape = PaddingShapeTo4dDefault(shape); |
|
|
|
} |
|
|
|
@@ -1219,6 +1224,7 @@ bool NchwToNc1hwc04(const FormatArgs &args, void *result) { |
|
|
|
MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04."; |
|
|
|
return NchwToNc1hwc0(args, result); |
|
|
|
} |
|
|
|
|
|
|
|
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) { |
|
|
|
MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw."; |
|
|
|
return Nc1hwc0ToNchw(args, result); |
|
|
|
|