backend/kernel_compiler/rts/memcpy_async.cc:

```diff
@@ -17,6 +17,7 @@
 #include "backend/kernel_compiler/rts/memcpy_async.h"
 #include <memory>
 #include <string>
+#include "abstract/utils.h"
 #include "runtime/mem.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "common/trans.h"
@@ -89,7 +90,7 @@ void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
   if (input_size != 1) {
     MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
   }
-  size_t type_size = trans::TypeIdSize(input_type_id_);
+  size_t type_size = abstract::TypeIdSize(input_type_id_);
   std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0);
   size_t total_size = 1;
   for (size_t i = 0; i < shape_i.size(); i++) {
```
backend/session/session_basic.cc:

```diff
@@ -20,6 +20,7 @@
 #include "c_ops/primitive_c.h"
 #include "ir/manager.h"
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "base/core_ops.h"
 #include "common/trans.h"
@@ -1093,7 +1094,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
       (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize);
       AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
                                           input_node.get());
-      size = trans::ShapeSize(shape_tmp) * trans::TypeIdSize(tensor->data_type());
+      size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
     }
     if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
       auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
```
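Only the namespace changes in the `LoadInputData` hunk above; the size arithmetic is untouched: `abstract::ShapeSize(shape_tmp)` is the element count (the product of the dimensions) and `abstract::TypeIdSize(...)` the per-element byte width, so, for example, a `kNumberTypeFloat32` tensor of shape `{2, 3, 4}` gives 24 * 4 = 96 bytes.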
common/trans.cc:

```diff
@@ -18,6 +18,7 @@
 #include <numeric>
 #include <utility>
 #include "utils/ms_utils.h"
+#include "abstract/utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/kernel.h"
 #include "runtime/device/convert_tensor_utils.h"
@@ -28,12 +29,6 @@
 namespace mindspore {
 namespace trans {
 enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
-const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1},
-                                           {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
-                                           {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2},
-                                           {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4},
-                                           {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
 inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
   switch (size) {
     case 1:
@@ -117,8 +112,8 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
   {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};
 void CheckMemSize(const TypeIdArgs &args) {
-  auto src_type_size = TypeIdSize(args.host_data_type);
-  auto dst_type_size = TypeIdSize(args.device_data_type);
+  auto src_type_size = abstract::TypeIdSize(args.host_data_type);
+  auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
   if (src_type_size < 1 || dst_type_size < 1) {
     MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
   }
@@ -192,7 +187,7 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const
 size_t CubeSizeByType(const TypeId data_type) {
   const size_t default_error = 0;
-  auto dt_size = TypeIdSize(data_type);
+  auto dt_size = abstract::TypeIdSize(data_type);
   if (dt_size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return default_error;
@@ -202,19 +197,6 @@ size_t CubeSizeByType(const TypeId data_type) {
   return kCubeSize;
 }
-size_t ShapeSize(const std::vector<size_t> &shape) {
-  return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
-}
-size_t TypeIdSize(const TypeId data_type) {
-  const size_t unsupported_type_error = 0;
-  auto iter = type_map.find(data_type);
-  if (iter != type_map.end()) {
-    return iter->second;
-  }
-  return unsupported_type_error;
-}
 namespace {
 bool CheckDims(const std::vector<size_t> &shape) {
   if (shape.size() != kNchwDims) {
```
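Where the deleted helpers went: given the new `abstract/utils.h` includes and every call site switching from `trans::` to `abstract::`, the two functions (together with the `type_map` table they depend on) have presumably been relocated into `abstract/utils.{h,cc}`. A minimal sketch of the relocated definitions; the placement is an assumption, but the bodies are taken verbatim from the lines removed above:

```cpp
// Sketch only: exact layout inside abstract/utils.cc is assumed, not confirmed.
namespace mindspore {
namespace abstract {
// Element count of a shape: product of all dimensions (1 for an empty shape).
size_t ShapeSize(const std::vector<size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
}

// Byte width of a TypeId, or 0 when the type is not found in type_map.
size_t TypeIdSize(const TypeId data_type) {
  const size_t unsupported_type_error = 0;
  auto iter = type_map.find(data_type);
  if (iter != type_map.end()) {
    return iter->second;
  }
  return unsupported_type_error;
}
}  // namespace abstract
}  // namespace mindspore
```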
common/trans.cc (continued):

```diff
@@ -477,12 +459,12 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
   }
   MS_EXCEPTION_IF_NULL(size);
   MS_EXCEPTION_IF_NULL(total_size);
-  *size = TypeIdSize(args.src_data_type);
+  *size = abstract::TypeIdSize(args.src_data_type);
   if (*size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  *total_size = ShapeSize(args.device_shape) * (*size);
+  *total_size = abstract::ShapeSize(args.device_shape) * (*size);
   if (*total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
     return false;
@@ -516,7 +498,7 @@ bool TransFormat(const FormatArgs &args, void *result) {
     {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
     {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -538,7 +520,7 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
     {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
     {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -624,7 +606,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
@@ -685,12 +667,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -828,13 +810,13 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }
-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -890,13 +872,13 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }
-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -947,12 +929,12 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -1005,12 +987,12 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
```
common/trans.h:

```diff
@@ -48,8 +48,6 @@ struct FormatArgs {
   TypeId src_data_type;
 };
-size_t TypeIdSize(const TypeId data_type);
-size_t ShapeSize(const std::vector<size_t> &shape);
 size_t CubeSizeByType(const TypeId data_type);
 std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});
```
runtime/device/ascend/ascend_device_address.cc:

```diff
@@ -26,6 +26,7 @@
 #include "runtime/device/convert_tensor_utils.h"
 #include "ir/dtype/type.h"
 #include "ir/tensor.h"
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
 #include "utils/utils.h"
@@ -298,7 +299,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     auto host = std::vector<uint8_t>(size_);
     SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size_};
@@ -413,11 +414,11 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
   auto host_size = size;
   if (type_id_ != type) {
-    auto device_dtype_size = trans::TypeIdSize(type_id_);
+    auto device_dtype_size = abstract::TypeIdSize(type_id_);
     if (device_dtype_size < 1) {
       MS_LOG(ERROR) << "Illegal dtype.";
     }
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     size = device_dtype_size * shape_size;
   }
   size = GetCommonAlignSize(size);
@@ -431,7 +432,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   } else {
     auto host = std::vector<uint8_t>(size);
     SyncMemory(host.data(), output_address->GetPtr(), size, RT_MEMCPY_DEVICE_TO_HOST);
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
     sync_ok = trans::TransDataType(type_args, host_ptr);
     if (!sync_ok) {
@@ -500,7 +501,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
     MS_LOG(ERROR) << "Trans format failed.";
     return false;
   }
-  auto shape_size = trans::ShapeSize(host_shape);
+  auto shape_size = abstract::ShapeSize(host_shape);
   const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size};
   sync_ok = trans::TransDataType(type_args, host_ptr);
   if (!sync_ok) {
@@ -537,7 +538,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());
@@ -581,7 +582,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());
```
runtime/device/ascend/ascend_kernel_runtime.cc:

```diff
@@ -22,6 +22,8 @@
 #include <exception>
 #include <algorithm>
 #include <thread>
+#include <stack>
+#include "abstract/primitive_infer_map.h"
 #include "debug/data_dump/e2e_dump_util.h"
 #include "runtime/device/ascend/ascend_device_address.h"
 #include "runtime/device/cpu/mpi/mpi_interface.h"
@@ -39,6 +41,7 @@
 #include "backend/session/anf_runtime_algorithm.h"
 #include "runtime/device/ascend/profiling/profiling_utils.h"
 #include "backend/kernel_compiler/tbe/tbe_utils.h"
+#include "backend/optimizer/common/helper.h"
 #include "runtime/device/ascend/ascend_memory_manager.h"
 #include "debug/tensor_load.h"
 #include "debug/data_dump/dump_json_parser.h"
@@ -110,6 +113,34 @@ std::string GetRankId() {
   }
   return rank_id_str;
 }
+
+void InferShapeForNopNode(AnfNodePtr *input_node) {
+  MS_EXCEPTION_IF_NULL(*input_node);
+  if (!opt::IsNopNode(*input_node)) {
+    MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
+    return;
+  }
+  MS_LOG(INFO) << "Infer shape for nop node.";
+  std::stack<AnfNodePtr> nop_road;
+  nop_road.push(*input_node);
+  while (true) {
+    auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
+    auto in_node = input_node_with_idx.first;
+    MS_EXCEPTION_IF_NULL(in_node);
+    if (opt::IsNopNode(in_node)) {
+      nop_road.push(in_node);
+      *input_node = in_node;
+    } else {
+      break;
+    }
+  }
+
+  while (!nop_road.empty()) {
+    auto nop_node = nop_road.top();
+    AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
+    nop_road.pop();
+  }
+}
 } // namespace
 
 std::vector<rtExceptionInfo> AscendKernelRuntime::exception_infoes_;
```
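How the new helper works: it first climbs from the given input through any chain of consecutive nop nodes toward the real producer, pushing each one onto `nop_road`; it then pops the stack, so shapes are re-inferred starting at the nop node nearest the real producer and ending at the original input, and every `AnfAlgo::InferShape` call sees already-refreshed input shapes. For orientation, a hypothetical sketch of the predicate it relies on (the real `opt::IsNopNode` comes from `backend/optimizer/common/helper.h`, included by this patch; the concrete op set below is an assumption):

```cpp
// Hypothetical illustration only. A "nop" node is a pass-through op whose
// output aliases its input buffer, so it performs no real computation, but
// its cached output shape goes stale when the producer's shape is dynamic.
bool IsNopNodeSketch(const AnfNodePtr &node) {
  static const std::set<std::string> kNopOps = {"Reshape", "ExpandDims", "Squeeze", "Flatten"};
  if (node == nullptr || !node->isa<CNode>()) {
    return false;
  }
  return kNopOps.count(AnfAlgo::GetCNodeName(node)) != 0;
}
```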
runtime/device/ascend/ascend_kernel_runtime.cc (continued):

```diff
@@ -633,6 +664,15 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap
     }
     if (dynamic_kernel->is_dynamic_shape()) {
+      auto kernel_node = dynamic_kernel->kernel_node();
+      MS_EXCEPTION_IF_NULL(kernel_node);
+      auto input_size = AnfAlgo::GetInputTensorNum(kernel_node);
+      for (size_t i = 0; i < input_size; i++) {
+        auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel_node, i);
+        auto input_node = input_node_with_index.first;
+        MS_EXCEPTION_IF_NULL(input_node);
+        InferShapeForNopNode(&input_node);
+      }
       dynamic_kernel->InferShape();
       dynamic_kernel->UpdateArgs();
     }
```
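Net effect of this hunk: before a dynamic-shape kernel infers its output shape, each of its direct inputs is traced through any intervening nop-node chain via `InferShapeForNopNode`, so `dynamic_kernel->InferShape()` reads input shapes that reflect the current execution rather than the shapes cached when the graph was compiled.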
runtime/device/executor/dynamic_kernel.h:

```diff
@@ -48,6 +48,7 @@ class DynamicKernel {
   virtual void Initialize();
   std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); }
   int GetKernelType();
+  CNodePtr kernel_node() const { return cnode_ptr_; }
 
  protected:
   void RebuildDependTensor();
```
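The new `kernel_node()` accessor is what the runtime loop above uses to reach the kernel's underlying `CNodePtr` (`cnode_ptr_`) without widening the rest of the `DynamicKernel` interface.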