@@ -17,6 +17,7 @@
 #include "backend/kernel_compiler/rts/memcpy_async.h"
 #include <memory>
 #include <string>
+#include "abstract/utils.h"
 #include "runtime/mem.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "common/trans.h"
@@ -89,7 +90,7 @@ void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
   if (input_size != 1) {
     MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
   }
-  size_t type_size = trans::TypeIdSize(input_type_id_);
+  size_t type_size = abstract::TypeIdSize(input_type_id_);
   std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0);
   size_t total_size = 1;
   for (size_t i = 0; i < shape_i.size(); i++) {
@@ -20,6 +20,7 @@
 #include "c_ops/primitive_c.h"
 #include "ir/manager.h"
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "base/core_ops.h"
 #include "common/trans.h"
@@ -1093,7 +1094,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
     (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize);
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
                                         input_node.get());
-    size = trans::ShapeSize(shape_tmp) * trans::TypeIdSize(tensor->data_type());
+    size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
   }
   if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
     auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
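
Reviewer note: the recomputed host size above is just element count * dtype width. A hypothetical worked example (shape and dtype are illustrative, not from this patch):

```cpp
std::vector<size_t> shape_tmp{2, 3, 4};
// abstract::ShapeSize -> 24 elements, abstract::TypeIdSize(kNumberTypeFloat32) -> 4 bytes
size_t size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(kNumberTypeFloat32);  // 96 bytes
```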
@@ -18,6 +18,7 @@
 #include <numeric>
 #include <utility>
 #include "utils/ms_utils.h"
+#include "abstract/utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/kernel.h"
 #include "runtime/device/convert_tensor_utils.h"
@@ -28,12 +29,6 @@
 namespace mindspore {
 namespace trans {
 enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
-const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},    {kNumberTypeInt, 4},     {kNumberTypeInt8, 1},
-                                           {kNumberTypeInt16, 2},   {kNumberTypeInt32, 4},   {kNumberTypeInt64, 8},
-                                           {kNumberTypeUInt, 4},    {kNumberTypeUInt8, 1},   {kNumberTypeUInt16, 2},
-                                           {kNumberTypeUInt32, 4},  {kNumberTypeUInt64, 8},  {kNumberTypeFloat, 4},
-                                           {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
 inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
   switch (size) {
     case 1:
@@ -117,8 +112,8 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
   {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};

 void CheckMemSize(const TypeIdArgs &args) {
-  auto src_type_size = TypeIdSize(args.host_data_type);
-  auto dst_type_size = TypeIdSize(args.device_data_type);
+  auto src_type_size = abstract::TypeIdSize(args.host_data_type);
+  auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
   if (src_type_size < 1 || dst_type_size < 1) {
     MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
   }
@@ -192,7 +187,7 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const
 size_t CubeSizeByType(const TypeId data_type) {
   const size_t default_error = 0;
-  auto dt_size = TypeIdSize(data_type);
+  auto dt_size = abstract::TypeIdSize(data_type);
   if (dt_size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return default_error;
@@ -202,19 +197,6 @@ size_t CubeSizeByType(const TypeId data_type) {
   return kCubeSize;
 }

-size_t ShapeSize(const std::vector<size_t> &shape) {
-  return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
-}
-
-size_t TypeIdSize(const TypeId data_type) {
-  const size_t unsupported_type_error = 0;
-  auto iter = type_map.find(data_type);
-  if (iter != type_map.end()) {
-    return iter->second;
-  }
-  return unsupported_type_error;
-}
-
 namespace {
 bool CheckDims(const std::vector<size_t> &shape) {
   if (shape.size() != kNchwDims) {
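
Reviewer note: the two helpers deleted above are not dropped; judging by the call-site changes, they are assumed to move into abstract/utils.h under mindspore::abstract with unchanged semantics. A minimal self-contained sketch of that contract (the enum values are stand-ins, not the real TypeId definitions):

```cpp
#include <cstddef>
#include <functional>
#include <map>
#include <numeric>
#include <vector>

// Stand-ins for a few mindspore TypeId values; the real enum lives elsewhere.
enum TypeId { kNumberTypeBool, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeInt64 };

namespace abstract {
// Element count of a shape; an empty shape multiplies out to 1 (a scalar).
inline size_t ShapeSize(const std::vector<size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
}

// Byte width of a dtype; 0 signals "unsupported", which is why every caller
// in this diff guards with `if (size < 1)`.
inline size_t TypeIdSize(TypeId data_type) {
  static const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},
                                                    {kNumberTypeFloat16, 2},
                                                    {kNumberTypeFloat32, 4},
                                                    {kNumberTypeFloat64, 8},
                                                    {kNumberTypeInt64, 8}};
  auto iter = type_map.find(data_type);
  return iter != type_map.end() ? iter->second : 0;
}
}  // namespace abstract
```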
@@ -477,12 +459,12 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
   }
   MS_EXCEPTION_IF_NULL(size);
   MS_EXCEPTION_IF_NULL(total_size);
-  *size = TypeIdSize(args.src_data_type);
+  *size = abstract::TypeIdSize(args.src_data_type);
   if (*size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  *total_size = ShapeSize(args.device_shape) * (*size);
+  *total_size = abstract::ShapeSize(args.device_shape) * (*size);
   if (*total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
     return false;
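
Reviewer note: a hypothetical sanity check of the validation above, for a float16 tensor with device shape {1, 1, 4, 4, 16} (NC1HWC0); the values are illustrative only:

```cpp
std::vector<size_t> device_shape{1, 1, 4, 4, 16};
// 1 * 1 * 4 * 4 * 16 = 256 elements, and float16 is 2 bytes wide
size_t total_size = abstract::ShapeSize(device_shape) * abstract::TypeIdSize(kNumberTypeFloat16);
// total_size == 512; CheckArgs returns false when args.device_size differs.
```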
@@ -516,7 +498,7 @@ bool TransFormat(const FormatArgs &args, void *result) {
     {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
     {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -538,7 +520,7 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
     {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
     {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -624,7 +606,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
@@ -685,12 +667,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -828,13 +810,13 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }
-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -890,13 +872,13 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }
-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -947,12 +929,12 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -1005,12 +987,12 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -48,8 +48,6 @@ struct FormatArgs {
   TypeId src_data_type;
 };
-size_t TypeIdSize(const TypeId data_type);
-size_t ShapeSize(const std::vector<size_t> &shape);
 size_t CubeSizeByType(const TypeId data_type);
 std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});
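
Reviewer note: with these two declarations deleted from trans.h, the counterparts are assumed to be declared in the newly included abstract/utils.h, roughly as follows (not shown in this diff):

```cpp
namespace mindspore {
namespace abstract {
size_t TypeIdSize(const TypeId data_type);
size_t ShapeSize(const std::vector<size_t> &shape);
}  // namespace abstract
}  // namespace mindspore
```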
@@ -26,6 +26,7 @@
 #include "runtime/device/convert_tensor_utils.h"
 #include "ir/dtype/type.h"
 #include "ir/tensor.h"
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
 #include "utils/utils.h"
@@ -298,7 +299,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     auto host = std::vector<uint8_t>(size_);
     SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size_};
@@ -413,11 +414,11 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
   auto host_size = size;
   if (type_id_ != type) {
-    auto device_dtype_size = trans::TypeIdSize(type_id_);
+    auto device_dtype_size = abstract::TypeIdSize(type_id_);
     if (device_dtype_size < 1) {
       MS_LOG(ERROR) << "Illegal dtype.";
     }
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     size = device_dtype_size * shape_size;
   }
   size = GetCommonAlignSize(size);
@@ -431,7 +432,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   } else {
     auto host = std::vector<uint8_t>(size);
     SyncMemory(host.data(), output_address->GetPtr(), size, RT_MEMCPY_DEVICE_TO_HOST);
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
     sync_ok = trans::TransDataType(type_args, host_ptr);
     if (!sync_ok) {
@@ -500,7 +501,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
       MS_LOG(ERROR) << "Trans format failed.";
       return false;
     }
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size};
     sync_ok = trans::TransDataType(type_args, host_ptr);
     if (!sync_ok) {
@@ -537,7 +538,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());
@@ -581,7 +582,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());
@@ -22,6 +22,8 @@
 #include <exception>
 #include <algorithm>
 #include <thread>
+#include <stack>
+#include "abstract/primitive_infer_map.h"
 #include "debug/data_dump/e2e_dump_util.h"
 #include "runtime/device/ascend/ascend_device_address.h"
 #include "runtime/device/cpu/mpi/mpi_interface.h"
@@ -39,6 +41,7 @@
 #include "backend/session/anf_runtime_algorithm.h"
 #include "runtime/device/ascend/profiling/profiling_utils.h"
 #include "backend/kernel_compiler/tbe/tbe_utils.h"
+#include "backend/optimizer/common/helper.h"
 #include "runtime/device/ascend/ascend_memory_manager.h"
 #include "debug/tensor_load.h"
 #include "debug/data_dump/dump_json_parser.h"
@@ -110,6 +113,34 @@ std::string GetRankId() {
   }
   return rank_id_str;
 }
+
+void InferShapeForNopNode(AnfNodePtr *input_node) {
+  MS_EXCEPTION_IF_NULL(*input_node);
+  if (!opt::IsNopNode(*input_node)) {
+    MS_LOG(INFO) << "Input node is not a nop node, no need to infer.";
+    return;
+  }
+  MS_LOG(INFO) << "Infer shape for nop node.";
+  std::stack<AnfNodePtr> nop_road;
+  nop_road.push(*input_node);
+
+  // Walk backwards through consecutive nop producers, pushing each onto the stack.
+  while (true) {
+    auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
+    auto in_node = input_node_with_idx.first;
+    MS_EXCEPTION_IF_NULL(in_node);
+    if (opt::IsNopNode(in_node)) {
+      nop_road.push(in_node);
+      *input_node = in_node;
+    } else {
+      break;
+    }
+  }
+
+  // Re-infer from the innermost nop node outwards so each node sees a fresh input shape.
+  while (!nop_road.empty()) {
+    auto nop_node = nop_road.top();
+    AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
+    nop_road.pop();
+  }
+}
 }  // namespace

 std::vector<rtExceptionInfo> AscendKernelRuntime::exception_infoes_;
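
Reviewer note on InferShapeForNopNode: nop nodes (reshape-style ops that reuse their input's device memory and are skipped at run time) can hold stale output shapes once inputs become dynamic. The function first collects the chain of consecutive nop producers on a stack, then re-infers from the innermost node outwards so each node sees its producer's fresh shape. A toy sketch of the same stack pattern, using a hypothetical single-input node type:

```cpp
#include <memory>
#include <stack>

// Hypothetical node type, for illustration only.
struct Node {
  bool is_nop = false;
  std::shared_ptr<Node> input;  // single producer, mirroring GetPrevNodeOutput(node, 0)
  void InferShape() { /* refresh the cached output shape from the input */ }
};

void InferNopChain(std::shared_ptr<Node> node) {
  if (!node || !node->is_nop) {
    return;
  }
  std::stack<std::shared_ptr<Node>> road;
  road.push(node);
  while (node->input && node->input->is_nop) {  // collect consecutive nop producers
    node = node->input;
    road.push(node);
  }
  while (!road.empty()) {  // innermost first, so shapes propagate outwards
    road.top()->InferShape();
    road.pop();
  }
}
```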
@@ -633,6 +664,15 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap
     }
     if (dynamic_kernel->is_dynamic_shape()) {
       auto kernel_node = dynamic_kernel->kernel_node();
+      MS_EXCEPTION_IF_NULL(kernel_node);
+      auto input_size = AnfAlgo::GetInputTensorNum(kernel_node);
+      // Nop-node inputs are skipped at run time, so refresh their shapes before inferring.
+      for (size_t i = 0; i < input_size; i++) {
+        auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel_node, i);
+        auto input_node = input_node_with_index.first;
+        MS_EXCEPTION_IF_NULL(input_node);
+        InferShapeForNopNode(&input_node);
+      }
       dynamic_kernel->InferShape();
       dynamic_kernel->UpdateArgs();
     }
@@ -48,6 +48,7 @@ class DynamicKernel {
   virtual void Initialize();
   std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); }
   int GetKernelType();
+  CNodePtr kernel_node() const { return cnode_ptr_; }

 protected:
   void RebuildDependTensor();
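
Reviewer note: the new kernel_node() accessor is what RunDynamicKernelAsync uses to reach the CNode behind a DynamicKernel; returning the CNodePtr (a std::shared_ptr) by value gives the caller shared ownership. A usage sketch mirroring the runtime hunk above:

```cpp
auto kernel_node = dynamic_kernel->kernel_node();  // CNodePtr copy keeps the node alive
MS_EXCEPTION_IF_NULL(kernel_node);
auto input_size = AnfAlgo::GetInputTensorNum(kernel_node);
```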