Browse Source

refactor TransShape

tags/v1.5.0-rc1
yuchaojie 4 years ago
parent
commit
0d56b1853c
2 changed files with 51 additions and 43 deletions
  1. +49
    -43
      mindspore/ccsrc/common/trans.cc
  2. +2
    -0
      mindspore/ccsrc/utils/utils.h

+ 49
- 43
mindspore/ccsrc/common/trans.cc View File

@@ -620,7 +620,27 @@ std::vector<int64_t> FracZDeviceShapeWithGroups(const std::vector<int64_t> &shap
return device_shape;
}

std::vector<int64_t> TransShapeToFracNZ(const std::vector<int64_t> &shape) {
std::vector<size_t> FracNZDeviceShape(const std::vector<size_t> &shape) {
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
// For [1] and [1024] shape we can trait it as NZ shape
return shape;
}
std::vector<size_t> device_shape;
if (shape.size() < 2) {
MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size();
} else {
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
}
auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1;
device_shape.push_back(w1);
device_shape.push_back(h1);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}

std::vector<int64_t> FracNZDeviceDynamicShape(const std::vector<int64_t> &shape) {
std::vector<int64_t> device_shape;
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
// For [1] and [1024] shape we can trait it as NZ shape
@@ -642,7 +662,21 @@ std::vector<int64_t> TransShapeToFracNZ(const std::vector<int64_t> &shape) {
return device_shape;
}

std::vector<int64_t> TransShapeToFracNZLSTM(const std::vector<int64_t> &shape) {
std::vector<size_t> FracNZLSTMDeviceShape(const std::vector<size_t> &shape) {
const size_t c0 = 4;
const size_t h = shape.at(kN) / c0;
const size_t i = shape.at(kC) - h;
const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
const size_t second = c0 * DivCeil(h, kCubeSize);
std::vector<size_t> device_shape;
device_shape.push_back(first);
device_shape.push_back(second);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}

std::vector<int64_t> FracNZLSTMDeviceDynamicShape(const std::vector<int64_t> &shape) {
std::vector<int64_t> device_shape;
const int64_t c0 = 4;
const int64_t h_shape = shape.at(kN);
@@ -693,8 +727,8 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) {
if (shape_size == 0) {
return false;
}
if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ || format == kOpFormat_ChannelLast ||
format == kOpFormat_NCHW) {
if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW ||
kNoPaddingFormatSet.find(format) != kNoPaddingFormatSet.end()) {
return false;
} else if (shape_size < kNchwDims) {
return true;
@@ -799,7 +833,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
{kOpFormat_NCDHW, NcdhwDeviceShape},
{kOpFormat_ChannelLast, ChannelLastDeviceShape},
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}};
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
{kOpFormat_FRAC_NZ, FracNZDeviceShape},
{kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};

if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
return shape;
@@ -808,37 +844,8 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
return FracZDeviceShapeWithGroups(shape, groups);
}
auto temp_shape = shape;
std::vector<size_t> device_shape;
if (format == kOpFormat_FRAC_NZ) {
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
// For [1] and [1024] shape we can trait it as NZ shape
return shape;
}
if (shape.size() < 2) {
MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
} else {
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
}
auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1;
device_shape.push_back(w1);
device_shape.push_back(h1);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
} else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
const size_t c0 = 4;
const size_t h = shape.at(kN) / c0;
const size_t i = shape.at(kC) - h;
const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
const size_t second = c0 * DivCeil(h, kCubeSize);
device_shape.push_back(first);
device_shape.push_back(second);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
if (format != kOpFormat_ChannelLast && shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
temp_shape = PaddingShapeTo4dDefault(shape);
}
@@ -867,7 +874,9 @@ std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const
{kOpFormat_NCDHW, NcdhwDeviceDynamicShape},
{kOpFormat_ChannelLast, ChannelLastDeviceDynamicShape},
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceDynamicShape},
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceDynamicShape}};
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceDynamicShape},
{kOpFormat_FRAC_NZ, FracNZDeviceDynamicShape},
{kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceDynamicShape}};

if (format == kOpFormat_ND || format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) {
return shape;
@@ -876,12 +885,8 @@ std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const
return FracZDeviceShapeWithGroups(shape, groups);
}
auto temp_shape = shape;
if (format == kOpFormat_FRAC_NZ) {
return TransShapeToFracNZ(shape);
} else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
return TransShapeToFracNZLSTM(shape);
}
if (format != kOpFormat_ChannelLast && shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
temp_shape = PaddingShapeTo4dDefault(shape);
}
@@ -1219,6 +1224,7 @@ bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
return NchwToNc1hwc0(args, result);
}

bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
return Nc1hwc0ToNchw(args, result);


+ 2
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -633,6 +633,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccid
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};

const std::set<std::string> kNoPaddingFormatSet = {kOpFormat_ChannelLast, kOpFormat_FRAC_NZ};

const std::set<std::string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceMinOpName,
kReduceMeanOpName, kReduceMaxOpName, kReduceAllOpName, kReduceAnyOpName, kConcatOpName};


Loading…
Cancel
Save