| @@ -14,11 +14,9 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "common/trans.h" | #include "common/trans.h" | ||||
| #include <algorithm> | |||||
| #include <functional> | #include <functional> | ||||
| #include <numeric> | #include <numeric> | ||||
| #include <utility> | #include <utility> | ||||
| #include "./securec.h" | |||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "kernel/kernel.h" | #include "kernel/kernel.h" | ||||
| @@ -29,34 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace trans { | namespace trans { | ||||
| namespace { | |||||
| std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) { | |||||
| std::vector<size_t> shape_4d(4, 1); | |||||
| switch (shape.size()) { | |||||
| case 0: | |||||
| return shape_4d; | |||||
| case 1: | |||||
| shape_4d[1] = shape[0]; | |||||
| break; | |||||
| case 2: | |||||
| shape_4d[1] = shape[0]; | |||||
| shape_4d[2] = shape[1]; | |||||
| break; | |||||
| case 3: | |||||
| shape_4d[1] = shape[0]; | |||||
| shape_4d[2] = shape[1]; | |||||
| shape_4d[3] = shape[2]; | |||||
| break; | |||||
| case 4: | |||||
| std::copy(shape.begin(), shape.end(), shape_4d.begin()); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size(); | |||||
| } | |||||
| return shape_4d; | |||||
| } | |||||
| } // namespace | |||||
| const size_t kNchwDims = 4; | |||||
| enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc }; | |||||
| const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1}, | const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1}, | ||||
| {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8}, | {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8}, | ||||
| {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2}, | {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2}, | ||||
| @@ -84,7 +55,10 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, | |||||
| template <typename T> | template <typename T> | ||||
| T DivCeil(T n1, T n2) { | T DivCeil(T n1, T n2) { | ||||
| return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | |||||
| if (n2 != 0) { | |||||
| return (n1 - 1) / n2 + 1; | |||||
| } | |||||
| return 0; | |||||
| } | } | ||||
| enum DataTypeTransMode { | enum DataTypeTransMode { | ||||
| @@ -226,8 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) { | |||||
| } | } | ||||
| size_t ShapeSize(const std::vector<size_t> &shape) { | size_t ShapeSize(const std::vector<size_t> &shape) { | ||||
| size_t product = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>()); | |||||
| return product; | |||||
| return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>()); | |||||
| } | } | ||||
| size_t TypeIdSize(const TypeId data_type) { | size_t TypeIdSize(const TypeId data_type) { | ||||
| @@ -239,57 +212,9 @@ size_t TypeIdSize(const TypeId data_type) { | |||||
| return unsupported_type_error; | return unsupported_type_error; | ||||
| } | } | ||||
| bool IsNeedPadding(const std::string &format, const size_t shape_size) { | |||||
| if (shape_size == 0) { | |||||
| return false; | |||||
| } | |||||
| if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) { | |||||
| return false; | |||||
| } else if (shape_size < 4) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) { | |||||
| std::vector<int> shape; | |||||
| std::vector<size_t> host_shape; | |||||
| if (node->isa<ValueNode>()) { | |||||
| auto value_node = node->cast<ValueNodePtr>(); | |||||
| auto node_value = value_node->value(); | |||||
| auto tensor = node_value->cast<tensor::TensorPtr>(); | |||||
| if (tensor == nullptr) { | |||||
| MS_LOG(EXCEPTION) << " the node[ " << node->DebugString() << "]'s cannot convert "; | |||||
| } | |||||
| auto shape_temp = tensor->shape(); | |||||
| (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize); | |||||
| if (host_shape.empty()) { | |||||
| host_shape.push_back(1); | |||||
| } | |||||
| } else { | |||||
| host_shape = AnfAlgo::GetOutputInferShape(node, index); | |||||
| } | |||||
| if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) { | |||||
| host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0)); | |||||
| } | |||||
| std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt); | |||||
| return shape; | |||||
| } | |||||
| std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) { | |||||
| if (padding_axis.empty() || shape.size() != padding_axis.size()) { | |||||
| return PaddingShapeTo4dByDefault(shape); | |||||
| } | |||||
| std::vector<size_t> shape_4d(4, 1); | |||||
| for (size_t index = 0; index < padding_axis.size(); index++) { | |||||
| shape_4d[padding_axis[index]] = shape[index]; | |||||
| } | |||||
| return shape_4d; | |||||
| } | |||||
| namespace { | namespace { | ||||
| bool CheckDims(const std::vector<size_t> &shape) { | bool CheckDims(const std::vector<size_t> &shape) { | ||||
| if (shape.size() != 4) { | |||||
| if (shape.size() != kNchwDims) { | |||||
| MS_LOG(ERROR) << "Host shape dims shoud be 4"; | MS_LOG(ERROR) << "Host shape dims shoud be 4"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -308,10 +233,10 @@ std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) { | |||||
| MS_LOG(EXCEPTION) << "Ccheck dims failed."; | MS_LOG(EXCEPTION) << "Ccheck dims failed."; | ||||
| } | } | ||||
| std::vector<size_t> device_shape; | std::vector<size_t> device_shape; | ||||
| device_shape.push_back(shape[0]); | |||||
| device_shape.push_back(shape[2]); | |||||
| device_shape.push_back(shape[3]); | |||||
| device_shape.push_back(shape[1]); | |||||
| device_shape.push_back(shape[kN]); | |||||
| device_shape.push_back(shape[kH]); | |||||
| device_shape.push_back(shape[kW]); | |||||
| device_shape.push_back(shape[kC]); | |||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| @@ -320,10 +245,10 @@ std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed."; | MS_LOG(EXCEPTION) << "Check dims failed."; | ||||
| } | } | ||||
| std::vector<size_t> device_shape; | std::vector<size_t> device_shape; | ||||
| device_shape.push_back(shape[2]); | |||||
| device_shape.push_back(shape[3]); | |||||
| device_shape.push_back(shape[1]); | |||||
| device_shape.push_back(shape[0]); | |||||
| device_shape.push_back(shape[kH]); | |||||
| device_shape.push_back(shape[kW]); | |||||
| device_shape.push_back(shape[kC]); | |||||
| device_shape.push_back(shape[kN]); | |||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| @@ -332,9 +257,9 @@ std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed."; | MS_LOG(EXCEPTION) << "Check dims failed."; | ||||
| } | } | ||||
| std::vector<size_t> device_shape; | std::vector<size_t> device_shape; | ||||
| size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize; | |||||
| size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize; | |||||
| device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize); | |||||
| const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize; | |||||
| const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize; | |||||
| device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize); | |||||
| device_shape.push_back(cout16 / kCubeSize); | device_shape.push_back(cout16 / kCubeSize); | ||||
| device_shape.push_back(kCubeSize); | device_shape.push_back(kCubeSize); | ||||
| device_shape.push_back(kCubeSize); | device_shape.push_back(kCubeSize); | ||||
| @@ -346,12 +271,12 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed."; | MS_LOG(EXCEPTION) << "Check dims failed."; | ||||
| } | } | ||||
| std::vector<size_t> device_shape; | std::vector<size_t> device_shape; | ||||
| size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; | |||||
| size_t C0 = kCubeSize; | |||||
| device_shape.push_back(shape[0]); | |||||
| const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize; | |||||
| const size_t C0 = kCubeSize; | |||||
| device_shape.push_back(shape[kN]); | |||||
| device_shape.push_back(C1); | device_shape.push_back(C1); | ||||
| device_shape.push_back(shape[2]); | |||||
| device_shape.push_back(shape[3]); | |||||
| device_shape.push_back(shape[kH]); | |||||
| device_shape.push_back(shape[kW]); | |||||
| device_shape.push_back(C0); | device_shape.push_back(C0); | ||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| @@ -361,10 +286,10 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed."; | MS_LOG(EXCEPTION) << "Check dims failed."; | ||||
| } | } | ||||
| std::vector<size_t> device_shape; | std::vector<size_t> device_shape; | ||||
| device_shape.push_back((shape[1] - 1) / kCubeSize + 1); | |||||
| device_shape.push_back(shape[2]); | |||||
| device_shape.push_back(shape[3]); | |||||
| device_shape.push_back(shape[0]); | |||||
| device_shape.push_back((shape[kC] - 1) / kCubeSize + 1); | |||||
| device_shape.push_back(shape[kH]); | |||||
| device_shape.push_back(shape[kW]); | |||||
| device_shape.push_back(shape[kN]); | |||||
| device_shape.push_back(kCubeSize); | device_shape.push_back(kCubeSize); | ||||
| device_shape.push_back(kCubeSize); | device_shape.push_back(kCubeSize); | ||||
| return device_shape; | return device_shape; | ||||
| @@ -375,9 +300,9 @@ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed."; | MS_LOG(EXCEPTION) << "Check dims failed."; | ||||
| } | } | ||||
| std::vector<size_t> device_shape; | std::vector<size_t> device_shape; | ||||
| size_t c0 = 4; | |||||
| auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize); | |||||
| auto no = DivCeil(shape.at(0), kCubeSize); | |||||
| const size_t c0 = 4; | |||||
| auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize); | |||||
| auto no = DivCeil(shape.at(kN), kCubeSize); | |||||
| device_shape.push_back(first_dim); | device_shape.push_back(first_dim); | ||||
| device_shape.push_back(no); | device_shape.push_back(no); | ||||
| device_shape.push_back(kCubeSize); | device_shape.push_back(kCubeSize); | ||||
| @@ -390,24 +315,101 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed."; | MS_LOG(EXCEPTION) << "Check dims failed."; | ||||
| } | } | ||||
| std::vector<size_t> device_shape; | std::vector<size_t> device_shape; | ||||
| size_t C1 = 1; | |||||
| size_t C0 = 4; | |||||
| device_shape.push_back(shape[0]); | |||||
| const size_t C1 = 1; | |||||
| const size_t C0 = 4; | |||||
| device_shape.push_back(shape[kN]); | |||||
| device_shape.push_back(C1); | device_shape.push_back(C1); | ||||
| device_shape.push_back(shape[2]); | |||||
| device_shape.push_back(shape[3]); | |||||
| device_shape.push_back(shape[kH]); | |||||
| device_shape.push_back(shape[kW]); | |||||
| device_shape.push_back(C0); | device_shape.push_back(C0); | ||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) { | std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) { | ||||
| if (shape.size() < 5) { | |||||
| if (shape.size() < kNdhwc) { | |||||
| MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; | MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; | ||||
| } | } | ||||
| return shape; | return shape; | ||||
| } | } | ||||
| std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) { | |||||
| std::vector<size_t> shape_4d(kNchwDims, 1); | |||||
| switch (shape.size()) { | |||||
| case 0: | |||||
| return shape_4d; | |||||
| case 1: | |||||
| shape_4d[kC] = shape[kN]; | |||||
| break; | |||||
| case 2: | |||||
| shape_4d[kC] = shape[kN]; | |||||
| shape_4d[kH] = shape[kC]; | |||||
| break; | |||||
| case 3: | |||||
| shape_4d[kC] = shape[kN]; | |||||
| shape_4d[kH] = shape[kC]; | |||||
| shape_4d[kW] = shape[kH]; | |||||
| break; | |||||
| case 4: | |||||
| std::copy(shape.begin(), shape.end(), shape_4d.begin()); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size(); | |||||
| } | |||||
| return shape_4d; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| bool IsNeedPadding(const std::string &format, const size_t shape_size) { | |||||
| if (shape_size == 0) { | |||||
| return false; | |||||
| } | |||||
| if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) { | |||||
| return false; | |||||
| } else if (shape_size < kNchwDims) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| std::vector<int> shape; | |||||
| std::vector<size_t> host_shape; | |||||
| if (node->isa<ValueNode>()) { | |||||
| auto value_node = node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| auto node_value = value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(node_value); | |||||
| auto tensor = node_value->cast<tensor::TensorPtr>(); | |||||
| if (tensor == nullptr) { | |||||
| MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert "; | |||||
| } | |||||
| auto shape_temp = tensor->shape(); | |||||
| (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize); | |||||
| if (host_shape.empty()) { | |||||
| host_shape.push_back(1); | |||||
| } | |||||
| } else { | |||||
| host_shape = AnfAlgo::GetOutputInferShape(node, index); | |||||
| } | |||||
| if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) { | |||||
| host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0)); | |||||
| } | |||||
| std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt); | |||||
| return shape; | |||||
| } | |||||
| std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) { | |||||
| if (padding_axis.empty() || shape.size() != padding_axis.size()) { | |||||
| return PaddingShapeTo4dByDefault(shape); | |||||
| } | |||||
| std::vector<size_t> shape_4d(kNchwDims, 1); | |||||
| for (size_t index = 0; index < padding_axis.size(); index++) { | |||||
| shape_4d[padding_axis[index]] = shape[index]; | |||||
| } | |||||
| return shape_4d; | |||||
| } | |||||
| std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { | std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) { | ||||
| using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; | using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; | ||||
| const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, | const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, | ||||
| @@ -439,7 +441,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s | |||||
| device_shape.push_back(kCubeSize); | device_shape.push_back(kCubeSize); | ||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| if (shape.size() != 4) { | |||||
| if (shape.size() != kNchwDims) { | |||||
| MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; | MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; | ||||
| temp_shape = PaddingShapeTo4dByDefault(shape); | temp_shape = PaddingShapeTo4dByDefault(shape); | ||||
| } | } | ||||
| @@ -455,6 +457,8 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { | |||||
| MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | ||||
| return false; | return false; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(size); | |||||
| MS_EXCEPTION_IF_NULL(total_size); | |||||
| *size = TypeIdSize(args.src_data_type); | *size = TypeIdSize(args.src_data_type); | ||||
| if (*size < 1) { | if (*size < 1) { | ||||
| MS_LOG(ERROR) << "Illegal dtype."; | MS_LOG(ERROR) << "Illegal dtype."; | ||||
| @@ -540,10 +544,10 @@ bool NchwTo4D(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Check args failed."; | MS_LOG(ERROR) << "Check args failed."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t n = args.host_shape[0]; | |||||
| size_t c = args.host_shape[1]; | |||||
| size_t h = args.host_shape[2]; | |||||
| size_t w = args.host_shape[3]; | |||||
| auto n = args.host_shape[kN]; | |||||
| auto c = args.host_shape[kC]; | |||||
| auto h = args.host_shape[kH]; | |||||
| auto w = args.host_shape[kW]; | |||||
| for (size_t ni = 0; ni < n; ni++) { | for (size_t ni = 0; ni < n; ni++) { | ||||
| for (size_t ci = 0; ci < c; ci++) { | for (size_t ci = 0; ci < c; ci++) { | ||||
| for (size_t hi = 0; hi < h; hi++) { | for (size_t hi = 0; hi < h; hi++) { | ||||
| @@ -572,10 +576,10 @@ bool ToNchw(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Check args failed."; | MS_LOG(ERROR) << "Check args failed."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t n = args.host_shape[0]; | |||||
| size_t c = args.host_shape[1]; | |||||
| size_t h = args.host_shape[2]; | |||||
| size_t w = args.host_shape[3]; | |||||
| auto n = args.host_shape[kN]; | |||||
| auto c = args.host_shape[kC]; | |||||
| auto h = args.host_shape[kH]; | |||||
| auto w = args.host_shape[kW]; | |||||
| for (size_t ni = 0; ni < n; ni++) { | for (size_t ni = 0; ni < n; ni++) { | ||||
| for (size_t ci = 0; ci < c; ci++) { | for (size_t ci = 0; ci < c; ci++) { | ||||
| for (size_t hi = 0; hi < h; hi++) { | for (size_t hi = 0; hi < h; hi++) { | ||||
| @@ -602,32 +606,32 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t size = TypeIdSize(args.src_data_type); | |||||
| auto size = TypeIdSize(args.src_data_type); | |||||
| if (size < 1) { | if (size < 1) { | ||||
| MS_LOG(ERROR) << "Illegal dtype."; | MS_LOG(ERROR) << "Illegal dtype."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto n = args.host_shape[0]; | |||||
| auto c = args.host_shape[1]; | |||||
| auto h = args.host_shape[2]; | |||||
| auto w = args.host_shape[3]; | |||||
| auto n = args.host_shape[kN]; | |||||
| auto c = args.host_shape[kC]; | |||||
| auto h = args.host_shape[kH]; | |||||
| auto w = args.host_shape[kW]; | |||||
| size_t c0 = CubeSizeByType(args.src_data_type); | |||||
| auto c0 = CubeSizeByType(args.src_data_type); | |||||
| if (c0 < 1) { | if (c0 < 1) { | ||||
| MS_LOG(ERROR) << "Illegal dtype."; | MS_LOG(ERROR) << "Illegal dtype."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t c1 = DivCeil(c, c0); | |||||
| size_t hw = h * w; | |||||
| size_t chw = c * hw; | |||||
| size_t hwc0 = hw * c0; | |||||
| size_t nchw = n * chw; | |||||
| size_t hf_cnt = DivCeil(n, kCubeSize); | |||||
| size_t vf_cnt = c1 * hw; | |||||
| size_t fractal_ele_cnt = c0 * kCubeSize; | |||||
| size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | |||||
| size_t dst_size = total_ele_cnt * size; | |||||
| auto c1 = DivCeil(c, c0); | |||||
| auto hw = h * w; | |||||
| auto chw = c * hw; | |||||
| auto hwc0 = hw * c0; | |||||
| auto nchw = n * chw; | |||||
| auto hf_cnt = DivCeil(n, kCubeSize); | |||||
| auto vf_cnt = c1 * hw; | |||||
| auto fractal_ele_cnt = c0 * kCubeSize; | |||||
| auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | |||||
| auto dst_size = total_ele_cnt * size; | |||||
| if (dst_size != args.device_size) { | if (dst_size != args.device_size) { | ||||
| MS_LOG(ERROR) << "Illegal total data size." | MS_LOG(ERROR) << "Illegal total data size." | ||||
| << "dst size is :" << dst_size << "device size is :" << args.device_size; | << "dst size is :" << dst_size << "device size is :" << args.device_size; | ||||
| @@ -647,7 +651,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { | |||||
| auto src_ni = hfi * kCubeSize + col; | auto src_ni = hfi * kCubeSize + col; | ||||
| auto src_idx = src_row_offset + chw * col; | auto src_idx = src_row_offset + chw * col; | ||||
| auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row; | auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row; | ||||
| auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? true : false; | |||||
| auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c; | |||||
| SetData(size, pad_zero, src_idx, dst_idx, args, result); | SetData(size, pad_zero, src_idx, dst_idx, args, result); | ||||
| } | } | ||||
| } | } | ||||
| @@ -663,12 +667,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t size = TypeIdSize(args.src_data_type); | |||||
| auto size = TypeIdSize(args.src_data_type); | |||||
| if (size < 1) { | if (size < 1) { | ||||
| MS_LOG(ERROR) << "Illegal dtype."; | MS_LOG(ERROR) << "Illegal dtype."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t total_size = ShapeSize(args.device_shape) * size; | |||||
| auto total_size = ShapeSize(args.device_shape) * size; | |||||
| if (total_size != args.device_size) { | if (total_size != args.device_size) { | ||||
| MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; | MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; | ||||
| return false; | return false; | ||||
| @@ -677,18 +681,16 @@ bool FracZToNchw(const FormatArgs &args, void *result) { | |||||
| auto n0 = args.device_shape.at(1); | auto n0 = args.device_shape.at(1); | ||||
| auto ni = args.device_shape.at(2); | auto ni = args.device_shape.at(2); | ||||
| auto c0 = args.device_shape.at(3); | auto c0 = args.device_shape.at(3); | ||||
| auto n = args.host_shape[0]; | |||||
| auto c = args.host_shape[1]; | |||||
| auto h = args.host_shape[2]; | |||||
| auto w = args.host_shape[3]; | |||||
| size_t nc = ni * n0; | |||||
| size_t ncc0 = nc * c0; | |||||
| size_t wncc0 = w * ncc0; | |||||
| size_t hwncc0 = h * wncc0; | |||||
| size_t hw = h * w; | |||||
| size_t chw = c * hw; | |||||
| auto n = args.host_shape[kN]; | |||||
| auto c = args.host_shape[kC]; | |||||
| auto h = args.host_shape[kH]; | |||||
| auto w = args.host_shape[kW]; | |||||
| auto nc = ni * n0; | |||||
| auto ncc0 = nc * c0; | |||||
| auto wncc0 = w * ncc0; | |||||
| auto hwncc0 = h * wncc0; | |||||
| auto hw = h * w; | |||||
| auto chw = c * hw; | |||||
| for (size_t n_idx = 0; n_idx < n; n_idx++) { | for (size_t n_idx = 0; n_idx < n; n_idx++) { | ||||
| size_t n_head_addr = n_idx * chw; | size_t n_head_addr = n_idx * chw; | ||||
| @@ -720,20 +722,18 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Check args failed."; | MS_LOG(ERROR) << "Check args failed."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t cube = kCubeSize; | |||||
| size_t n = args.host_shape[0]; | |||||
| size_t c = args.host_shape[1]; | |||||
| size_t h = args.host_shape[2]; | |||||
| size_t w = args.host_shape[3]; | |||||
| size_t c0 = 4; | |||||
| size_t c1 = DivCeil(c, c0); | |||||
| size_t hwc0 = h * w * c0; | |||||
| size_t hwc = h * w * c; | |||||
| size_t nhwc = n * h * w * c; | |||||
| size_t n_cnt = DivCeil(n, cube); | |||||
| size_t v_cnt = DivCeil(h * w * c0 * c1, cube); | |||||
| auto cube = kCubeSize; | |||||
| auto n = args.host_shape[kN]; | |||||
| auto c = args.host_shape[kC]; | |||||
| auto h = args.host_shape[kH]; | |||||
| auto w = args.host_shape[kW]; | |||||
| const size_t c0 = 4; | |||||
| auto c1 = DivCeil(c, c0); | |||||
| auto hwc0 = h * w * c0; | |||||
| auto hwc = h * w * c; | |||||
| auto nhwc = n * h * w * c; | |||||
| auto n_cnt = DivCeil(n, cube); | |||||
| auto v_cnt = DivCeil(h * w * c0 * c1, cube); | |||||
| size_t dst_idx = 0; | size_t dst_idx = 0; | ||||
| for (size_t vi = 0; vi < v_cnt; vi++) { | for (size_t vi = 0; vi < v_cnt; vi++) { | ||||
| @@ -929,7 +929,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t size = TypeIdSize(args.src_data_type); | |||||
| auto size = TypeIdSize(args.src_data_type); | |||||
| if (size < 1) { | if (size < 1) { | ||||
| MS_LOG(ERROR) << "Illegal dtype."; | MS_LOG(ERROR) << "Illegal dtype."; | ||||
| return false; | return false; | ||||
| @@ -940,20 +940,23 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto n = args.host_shape[0]; | |||||
| auto c = args.host_shape[1]; | |||||
| auto h = args.host_shape[2]; | |||||
| auto w = args.host_shape[3]; | |||||
| size_t c0 = CubeSizeByType(args.src_data_type); | |||||
| auto n = args.host_shape[kN]; | |||||
| auto c = args.host_shape[kC]; | |||||
| auto h = args.host_shape[kH]; | |||||
| auto w = args.host_shape[kW]; | |||||
| auto c0 = CubeSizeByType(args.src_data_type); | |||||
| if (c0 < 1) { | if (c0 < 1) { | ||||
| MS_LOG(ERROR) << "Illegal dtype."; | MS_LOG(ERROR) << "Illegal dtype."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t c1 = DivCeil(c, c0); | |||||
| size_t hw = h * w; | |||||
| size_t chw = c * hw; | |||||
| size_t c1hwc0 = c1 * hw * c0; | |||||
| size_t wc0 = w * c0; | |||||
| if (args.device_format == kOpFormat_NC1HWC0_C04) { | |||||
| c0 = 4; | |||||
| } | |||||
| auto c1 = DivCeil(c, c0); | |||||
| auto hw = h * w; | |||||
| auto chw = c * hw; | |||||
| auto c1hwc0 = c1 * hw * c0; | |||||
| auto wc0 = w * c0; | |||||
| for (size_t n_idx = 0; n_idx < n; n_idx++) { | for (size_t n_idx = 0; n_idx < n; n_idx++) { | ||||
| size_t n_head_addr = n_idx * c1hwc0; | size_t n_head_addr = n_idx * c1hwc0; | ||||
| @@ -967,7 +970,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { | |||||
| size_t dst_idx = c0_idx + w_head_addr; | size_t dst_idx = c0_idx + w_head_addr; | ||||
| size_t c_idx = c0_idx + c1_idx * c0; | size_t c_idx = c0_idx + c1_idx * c0; | ||||
| size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx; | size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx; | ||||
| auto pad_zero = (c_idx < c) ? false : true; | |||||
| auto pad_zero = c_idx >= c; | |||||
| SetData(size, pad_zero, src_idx, dst_idx, args, result); | SetData(size, pad_zero, src_idx, dst_idx, args, result); | ||||
| } | } | ||||
| } | } | ||||
| @@ -984,29 +987,29 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t size = TypeIdSize(args.src_data_type); | |||||
| auto size = TypeIdSize(args.src_data_type); | |||||
| if (size < 1) { | if (size < 1) { | ||||
| MS_LOG(ERROR) << "Illegal dtype."; | MS_LOG(ERROR) << "Illegal dtype."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| size_t total_size = ShapeSize(args.device_shape) * size; | |||||
| auto total_size = ShapeSize(args.device_shape) * size; | |||||
| if (total_size != args.device_size) { | if (total_size != args.device_size) { | ||||
| MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; | MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto n = args.host_shape[0]; | |||||
| auto c = args.host_shape[1]; | |||||
| auto h = args.host_shape[2]; | |||||
| auto w = args.host_shape[3]; | |||||
| auto n = args.host_shape[kN]; | |||||
| auto c = args.host_shape[kC]; | |||||
| auto h = args.host_shape[kH]; | |||||
| auto w = args.host_shape[kW]; | |||||
| auto c1 = args.device_shape[1]; | auto c1 = args.device_shape[1]; | ||||
| auto c0 = args.device_shape[4]; | auto c0 = args.device_shape[4]; | ||||
| size_t hw = h * w; | |||||
| size_t chw = c * hw; | |||||
| size_t wc0 = w * c0; | |||||
| size_t hwc0 = h * wc0; | |||||
| size_t c1hwc0 = c1 * hwc0; | |||||
| auto hw = h * w; | |||||
| auto chw = c * hw; | |||||
| auto wc0 = w * c0; | |||||
| auto hwc0 = h * wc0; | |||||
| auto c1hwc0 = c1 * hwc0; | |||||
| for (size_t n_idx = 0; n_idx < n; n_idx++) { | for (size_t n_idx = 0; n_idx < n; n_idx++) { | ||||
| size_t n_head_addr = n_idx * chw; | size_t n_head_addr = n_idx * chw; | ||||
| @@ -1037,13 +1040,15 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Check args failed."; | MS_LOG(ERROR) << "Check args failed."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto n = args.host_shape[0]; | |||||
| auto c = args.host_shape[1]; | |||||
| auto h = args.host_shape[2]; | |||||
| auto w = args.host_shape[3]; | |||||
| auto n = args.host_shape[kN]; | |||||
| auto c = args.host_shape[kC]; | |||||
| auto h = args.host_shape[kH]; | |||||
| auto w = args.host_shape[kW]; | |||||
| const int co_idx = 4; | |||||
| const int c0_idx = 5; | |||||
| auto c1 = args.device_shape[0]; | auto c1 = args.device_shape[0]; | ||||
| auto co = args.device_shape[4]; | |||||
| auto c0 = args.device_shape[5]; | |||||
| auto co = args.device_shape[co_idx]; | |||||
| auto c0 = args.device_shape[c0_idx]; | |||||
| for (size_t c1_i = 0; c1_i < c1; c1_i++) { | for (size_t c1_i = 0; c1_i < c1; c1_i++) { | ||||
| for (size_t h_i = 0; h_i < h; h_i++) { | for (size_t h_i = 0; h_i < h; h_i++) { | ||||
| @@ -1055,7 +1060,7 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) { | |||||
| co_i * c0 + c0_i; | co_i * c0 + c0_i; | ||||
| size_t c_i = c0_i + c1_i * c0; | size_t c_i = c0_i + c1_i * c0; | ||||
| size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i; | size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i; | ||||
| auto pad_zero = (c_i < c && c0_i == co_i) ? false : true; | |||||
| auto pad_zero = !(c_i < c && c0_i == co_i); | |||||
| SetData(size, pad_zero, src_idx, dst_idx, args, result); | SetData(size, pad_zero, src_idx, dst_idx, args, result); | ||||
| } | } | ||||
| } | } | ||||
| @@ -1076,12 +1081,14 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) { | |||||
| MS_LOG(ERROR) << "Check args failed."; | MS_LOG(ERROR) << "Check args failed."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto n = args.host_shape[0]; | |||||
| auto c = args.host_shape[1]; | |||||
| auto h = args.host_shape[2]; | |||||
| auto w = args.host_shape[3]; | |||||
| auto co = args.device_shape[4]; | |||||
| auto c0 = args.device_shape[5]; | |||||
| auto n = args.host_shape[kN]; | |||||
| auto c = args.host_shape[kC]; | |||||
| auto h = args.host_shape[kH]; | |||||
| auto w = args.host_shape[kW]; | |||||
| const int co_idx = 4; | |||||
| const int c0_idx = 5; | |||||
| auto co = args.device_shape[co_idx]; | |||||
| auto c0 = args.device_shape[c0_idx]; | |||||
| for (size_t n_i = 0; n_i < n; n_i++) { | for (size_t n_i = 0; n_i < n; n_i++) { | ||||
| for (size_t c_i = 0; c_i < c; c_i++) { | for (size_t c_i = 0; c_i < c; c_i++) { | ||||
| for (size_t h_i = 0; h_i < h; h_i++) { | for (size_t h_i = 0; h_i < h; h_i++) { | ||||