diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 299fc25550..2cf7cff113 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -554,8 +554,7 @@ std::vector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr & if (trans::IsNeedPadding(format, infer_shape.size())) { infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx)); } - auto dtype = AnfAlgo::GetOutputDeviceDataType(node, output_idx); - return trans::TransShapeToDevice(infer_shape, format, dtype); + return trans::TransShapeToDevice(infer_shape, format); } std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { @@ -568,8 +567,7 @@ std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n if (trans::IsNeedPadding(format, infer_shape.size())) { infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); } - auto dtype = AnfAlgo::GetInputDeviceDataType(node, input_idx); - return trans::TransShapeToDevice(infer_shape, format, dtype); + return trans::TransShapeToDevice(infer_shape, format); } std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { @@ -1614,8 +1612,7 @@ std::vector AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const An auto max_shape = GetInputMaxShape(anf_node, index); std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize); auto format = GetInputFormat(anf_node, index); - auto dtype = GetInputDeviceDataType(anf_node, index); - trans::TransShapeToDevice(device_shape, format, dtype); + trans::TransShapeToDevice(device_shape, format); } return device_shape; } @@ -1627,8 +1624,7 @@ std::vector AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A auto max_shape = GetOutputMaxShape(anf_node, index); std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize); auto format = GetOutputFormat(anf_node, index); - auto dtype = GetOutputDeviceDataType(anf_node, index); - trans::TransShapeToDevice(device_shape, format, dtype); + trans::TransShapeToDevice(device_shape, format); } return device_shape; } diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 63203f89aa..bc3d5096e1 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -189,7 +189,7 @@ size_t CubeSizeByType(const TypeId data_type) { const size_t default_error = 0; auto dt_size = abstract::TypeIdSize(data_type); if (dt_size < 1) { - MS_LOG(EXCEPTION) << "Illegal dtype."; + MS_LOG(ERROR) << "Illegal dtype."; return default_error; } else if (dt_size == 1) { return kCubeSize * 2; @@ -206,14 +206,14 @@ bool CheckDims(const std::vector &shape) { return true; } -std::vector NchwDeviceShape(const std::vector &shape, const TypeId &type) { +std::vector NchwDeviceShape(const std::vector &shape) { if (!CheckDims(shape)) { MS_LOG(EXCEPTION) << "Check dims failed."; } return shape; } -std::vector NhwcDeviceShape(const std::vector &shape, const TypeId &type) { +std::vector NhwcDeviceShape(const std::vector &shape) { if (!CheckDims(shape)) { MS_LOG(EXCEPTION) << "Ccheck dims failed."; } @@ -225,7 +225,7 @@ std::vector NhwcDeviceShape(const std::vector &shape, const Type return device_shape; } -std::vector HwchDeviceShape(const std::vector &shape, const TypeId &type) { +std::vector HwchDeviceShape(const std::vector &shape) { if (!CheckDims(shape)) { MS_LOG(EXCEPTION) << "Check dims failed."; } @@ -237,29 +237,27 @@ std::vector HwchDeviceShape(const std::vector &shape, const Type return device_shape; } -std::vector FracZDeviceShape(const std::vector &shape, const TypeId &type) { +std::vector FracZDeviceShape(const std::vector &shape) { if (!CheckDims(shape)) { MS_LOG(EXCEPTION) << "Check dims failed."; } - auto kCube = CubeSizeByType(type); std::vector device_shape; - auto c1 = DivCeil(shape[kC], kCube); - auto n0 = DivCeil(shape[kN], kCubeSize); - device_shape.push_back(shape[kH] * shape[kW] * c1); - device_shape.push_back(n0); + const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize; + const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize; + device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize); + device_shape.push_back(cout16 / kCubeSize); + device_shape.push_back(kCubeSize); device_shape.push_back(kCubeSize); - device_shape.push_back(kCube); return device_shape; } -std::vector Nc1hwc0DeviceShape(const std::vector &shape, const TypeId &type) { +std::vector Nc1hwc0DeviceShape(const std::vector &shape) { if (!CheckDims(shape)) { MS_LOG(EXCEPTION) << "Check dims failed."; } - auto kCube = CubeSizeByType(type); std::vector device_shape; - const size_t C1 = (shape[kC] + kCube - 1) / kCube; - const size_t C0 = kCube; + const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize; + const size_t C0 = kCubeSize; device_shape.push_back(shape[kN]); device_shape.push_back(C1); device_shape.push_back(shape[kH]); @@ -268,7 +266,7 @@ std::vector Nc1hwc0DeviceShape(const std::vector &shape, const T return device_shape; } -std::vector Ndc1hwc0DeviceShape(const std::vector &shape, const TypeId &type) { +std::vector Ndc1hwc0DeviceShape(const std::vector &shape) { // NCDHW if (shape.size() != 5) { MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); @@ -285,54 +283,51 @@ std::vector Ndc1hwc0DeviceShape(const std::vector &shape, const return device_shape; } -std::vector Fracz3DDeviceShape(const std::vector &shape, const TypeId &type) { +std::vector Fracz3DDeviceShape(const std::vector &shape) { // NCDHW -> Frac_Z_3D if (shape.size() != 5) { MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); } - auto kCube = CubeSizeByType(type); std::vector device_shape; - const size_t C1 = (shape[1] + kCube - 1) / kCube; + const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize; device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]); device_shape.push_back(N1); device_shape.push_back(kCubeSize); - device_shape.push_back(kCube); + device_shape.push_back(kCubeSize); return device_shape; } -std::vector C1hwncoc0DeviceShape(const std::vector &shape, const TypeId &type) { +std::vector C1hwncoc0DeviceShape(const std::vector &shape) { if (!CheckDims(shape)) { MS_LOG(EXCEPTION) << "Check dims failed."; } - auto kCube = CubeSizeByType(type); std::vector device_shape; - device_shape.push_back((shape[kC] - 1) / kCube + 1); + device_shape.push_back((shape[kC] - 1) / kCubeSize + 1); device_shape.push_back(shape[kH]); device_shape.push_back(shape[kW]); device_shape.push_back(shape[kN]); - device_shape.push_back(kCube); - device_shape.push_back(kCube); + device_shape.push_back(kCubeSize); + device_shape.push_back(kCubeSize); return device_shape; } -std::vector FracZc04DeviceShape(const std::vector &shape, const TypeId &type) { +std::vector FracZc04DeviceShape(const std::vector &shape) { if (!CheckDims(shape)) { MS_LOG(EXCEPTION) << "Check dims failed."; } - auto kCube = CubeSizeByType(type); std::vector device_shape; const size_t c0 = 4; - auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCube); - auto no = DivCeil(shape.at(kN), kCube); + auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize); + auto no = DivCeil(shape.at(kN), kCubeSize); device_shape.push_back(first_dim); device_shape.push_back(no); - device_shape.push_back(kCube); - device_shape.push_back(kCube); + device_shape.push_back(kCubeSize); + device_shape.push_back(kCubeSize); return device_shape; } -std::vector Nc1hwc04DeviceShape(const std::vector &shape, const TypeId &type) { +std::vector Nc1hwc04DeviceShape(const std::vector &shape) { if (!CheckDims(shape)) { MS_LOG(EXCEPTION) << "Check dims failed."; } @@ -347,7 +342,7 @@ std::vector Nc1hwc04DeviceShape(const std::vector &shape, const return device_shape; } -std::vector NcdhwDeviceShape(const std::vector &shape, const TypeId &type) { +std::vector NcdhwDeviceShape(const std::vector &shape) { if (shape.size() < kNdhwc) { MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; } @@ -432,9 +427,8 @@ std::vector PaddingShapeTo4d(const std::vector &shape, const std return shape_4d; } -std::vector TransShapeToDevice(const std::vector &shape, const std::string &format, - const TypeId &type) { - using DeviceShapeTransfer = std::function(const std::vector &, const TypeId &)>; +std::vector TransShapeToDevice(const std::vector &shape, const std::string &format) { + using DeviceShapeTransfer = std::function(const std::vector &)>; const std::map device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape}, {kOpFormat_HWCN, HwchDeviceShape}, @@ -452,9 +446,8 @@ std::vector TransShapeToDevice(const std::vector &shape, const s } auto temp_shape = shape; std::vector device_shape; - auto kCube = CubeSizeByType(type); if (format == kOpFormat_FRAC_NZ) { - if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCube == 0)) { + if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) { // For [1] and [1024] shape we can trait it as NZ shape return shape; } @@ -463,12 +456,12 @@ std::vector TransShapeToDevice(const std::vector &shape, const s } else { (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape)); } - auto w1 = (shape[shape.size() - 1] - 1) / kCube + 1; auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1; + auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1; device_shape.push_back(w1); device_shape.push_back(h1); device_shape.push_back(kCubeSize); - device_shape.push_back(kCube); + device_shape.push_back(kCubeSize); return device_shape; } else if (format == kOpFormat_FRACTAL_ZN_LSTM) { const size_t c0 = 4; @@ -490,7 +483,7 @@ std::vector TransShapeToDevice(const std::vector &shape, const s if (iter == device_shape_map.end()) { MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]"; } - return iter->second(temp_shape, type); + return iter->second(temp_shape); } bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index 8255dd8a34..9014f3c051 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -53,7 +53,7 @@ size_t CubeSizeByType(const TypeId data_type); std::vector PaddingShapeTo4d(const std::vector &shape, const std::vector &padding_axis = {}); ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index); bool IsNeedPadding(const std::string &format, const size_t shape_size); -std::vector TransShapeToDevice(const std::vector &shape, const std::string &format, const TypeId &type); +std::vector TransShapeToDevice(const std::vector &shape, const std::string &format); bool TransDataType(const TypeIdArgs &args, void *result); bool TransFormat(const FormatArgs &args, void *result); bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index 034f0baa3f..9c669fd510 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -455,7 +455,7 @@ std::vector AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::js std::vector AscendDeviceAddress::GetDeviceShape(std::vector *host_shape) const { std::vector device_shape; if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) { - device_shape = trans::TransShapeToDevice(*host_shape, format_, type_id_); + device_shape = trans::TransShapeToDevice(*host_shape, format_); } else { if (host_shape_.empty()) { *host_shape = trans::PaddingShapeTo4d(*host_shape); @@ -463,7 +463,7 @@ std::vector AscendDeviceAddress::GetDeviceShape(std::vector *hos host_shape->clear(); (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), LongToSize); } - device_shape = trans::TransShapeToDevice(*host_shape, format_, type_id_); + device_shape = trans::TransShapeToDevice(*host_shape, format_); } return device_shape; } @@ -577,10 +577,10 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh std::vector device_shape; if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0 || format_ == kOpFormat_FRACTAL_Z_3D) { - device_shape = trans::TransShapeToDevice(host_shape, format_, type_id_); + device_shape = trans::TransShapeToDevice(host_shape, format_); } else { host_shape = trans::PaddingShapeTo4d(host_shape); - device_shape = trans::TransShapeToDevice(host_shape, format_, type_id_); + device_shape = trans::TransShapeToDevice(host_shape, format_); } if (type_id_ != type) { auto shape_size = abstract::ShapeSize(host_shape); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index d78ebcd134..2f1373c7ec 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -69,8 +69,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod auto format = AnfAlgo::GetOutputFormat(node, output_index); if (shape.empty() && format != kOpFormat_DEFAULT) { shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index)); - auto dtype = AnfAlgo::GetOutputDeviceDataType(node, output_index); - shape = trans::TransShapeToDevice(shape, format, dtype); + shape = trans::TransShapeToDevice(shape, format); } // scalar's output shape is a empty vector size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies());