Merge pull request !260 from lianliguang/refactor-padding-strategy
@@ -20,6 +20,8 @@
 #include <utility>
 #include "./securec.h"
 #include "common/utils.h"
+#include "session/anf_runtime_algorithm.h"
+#include "kernel/kernel.h"
 #include "device/convert_tensor_utils.h"
 #include "utils/convert_utils.h"
 #include "utils/log_adapter.h"
@@ -27,6 +29,33 @@
 namespace mindspore {
 namespace trans {
+namespace {
+std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
+  std::vector<size_t> shape_4d(4, 1);
+  switch (shape.size()) {
+    case 0:
+      return shape_4d;
+    case 1:
+      shape_4d[1] = shape[0];
+      break;
+    case 2:
+      shape_4d[1] = shape[0];
+      shape_4d[2] = shape[1];
+      break;
+    case 3:
+      shape_4d[1] = shape[0];
+      shape_4d[2] = shape[1];
+      shape_4d[3] = shape[2];
+      break;
+    case 4:
+      std::copy(shape.begin(), shape.end(), shape_4d.begin());
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
+  }
+  return shape_4d;
+}
+}  // namespace
 const size_t kNchwDims = 4;
 const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},  {kNumberTypeInt, 4},   {kNumberTypeInt8, 1},
                                            {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
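
For readers skimming the refactor: the new default strategy pads a rank-1 to rank-3 shape into the C/H/W slots of an NCHW shape, leaving missing dimensions as 1. A minimal standalone sketch (not the MindSpore source, names are illustrative):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Re-implementation of the default padding rule above for illustration:
// ranks 1-3 fill the C/H/W slots (indices 1..3) and missing dims stay 1.
std::vector<size_t> PadTo4dByDefault(const std::vector<size_t> &shape) {
  if (shape.size() == 4) {
    return shape;
  }
  std::vector<size_t> shape_4d(4, 1);
  for (size_t i = 0; i < shape.size() && i < 3; ++i) {
    shape_4d[i + 1] = shape[i];  // dimension i lands at C, H, or W
  }
  return shape_4d;
}

int main() {
  std::vector<std::vector<size_t>> cases = {{}, {32}, {8, 16}, {8, 16, 32}};
  for (const auto &s : cases) {
    for (size_t d : PadTo4dByDefault(s)) std::cout << d << ' ';
    std::cout << '\n';  // 1 1 1 1 / 1 32 1 1 / 1 8 16 1 / 1 8 16 32
  }
}
```

Note that this differs from the removed TransShapeTo4d below, which padded a rank-2 shape into the N/C slots instead.
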
@@ -154,38 +183,64 @@ size_t TypeIdSize(const TypeId data_type) {
   return unsupported_type_error;
 }
 
-std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape) {
+bool IsNeedPadding(const std::string &format, const size_t shape_size) {
+  if (shape_size == 0) {
+    return false;
+  }
+  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
+    return false;
+  } else if (shape_size < 4) {
+    return true;
+  }
+  return false;
+}
+
+std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
+  std::vector<int> shape;
+  std::vector<size_t> host_shape;
+  if (node->isa<ValueNode>()) {
+    auto value_node = node->cast<ValueNodePtr>();
+    auto node_value = value_node->value();
+    auto tensor = node_value->cast<tensor::TensorPtr>();
+    if (tensor == nullptr) {
+      MS_LOG(EXCEPTION) << "The node [" << node->DebugString() << "] cannot be converted to a tensor";
+    }
+    shape = tensor->shape();
+    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
+    if (host_shape.empty()) {
+      host_shape.push_back(1);
+    }
+  } else {
+    host_shape = AnfAlgo::GetOutputInferShape(node, index);
+  }
+  if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
+    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
+  }
+  std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
+  return shape;
+}
+
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
+  if (padding_axis.empty() || shape.size() != padding_axis.size()) {
+    return PaddingShapeTo4dByDefault(shape);
+  }
   std::vector<size_t> shape_4d(4, 1);
-  switch (shape.size()) {
-    case 0:
-      break;
-    case 1:
-      shape_4d[1] = shape[0];
-      break;
-    case 2:
-      shape_4d[0] = shape[0];
-      shape_4d[1] = shape[1];
-      break;
-    case 3:
-      MS_LOG(EXCEPTION) << "Unexpected shape size = 3,it should has a default format";
-    case 4:
-      for (size_t i = 0; i < 4; ++i) {
-        shape_4d[i] = shape[i];
-      }
-      break;
-    default:
-      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
+  for (size_t index = 0; index < padding_axis.size(); index++) {
+    shape_4d[padding_axis[index]] = shape[index];
   }
   return shape_4d;
 }
 
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
+  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
+    return shape;
+  }
+  auto temp_shape = shape;
   std::vector<size_t> device_shape;
   if (format == kOpFormat_FRAC_NZ) {
     if (shape.size() < 2) {
-      MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
-    } else {
+      MS_EXCEPTION(NotSupportError) << "Format " << format << " does not support shape size " << shape.size();
+    }
+    if (shape.size() > 2) {
       (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
     }
     auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
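
The axis-directed overload above is the heart of the padding refactor: a reshape type tells each input dimension which NCHW slot it belongs to, instead of guessing. A hedged standalone sketch; `kernel::Axis` lives in kernel/kernel.h, and the enumerator values N=0, C=1, H=2, W=3 used here are an assumption:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Assumed layout of the kernel::Axis enum (illustrative, not the real header).
enum Axis : size_t { N = 0, C = 1, H = 2, W = 3 };

std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
  std::vector<size_t> shape_4d(4, 1);
  // Each input dimension lands in the NCHW slot named by its reshape type,
  // rather than the fixed C/H/W slots of the default strategy.
  for (size_t index = 0; index < padding_axis.size(); index++) {
    shape_4d[padding_axis[index]] = shape[index];
  }
  return shape_4d;
}

int main() {
  // A 2-D weight annotated as {C, N} pads to {16, 32, 1, 1},
  // where the default strategy would produce {1, 32, 16, 1}.
  for (size_t d : PaddingShapeTo4d({32, 16}, {C, N})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 16 32 1 1
}
```
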
@@ -197,35 +252,36 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
     return device_shape;
   }
   if (shape.size() != 4) {
-    MS_LOG(EXCEPTION) << "shape_4d size should be 4";
+    MS_LOG(WARNING) << "Got a device shape whose size is less than 4; padding it to 4-D with the default strategy first";
+    temp_shape = PaddingShapeTo4dByDefault(shape);
   }
   if (format == kOpFormat_NC1HWC0) {
-    size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
+    size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize;
     size_t C0 = kCubeSize;
-    device_shape.push_back(shape[0]);
+    device_shape.push_back(temp_shape[0]);
     device_shape.push_back(C1);
-    device_shape.push_back(shape[2]);
-    device_shape.push_back(shape[3]);
+    device_shape.push_back(temp_shape[2]);
+    device_shape.push_back(temp_shape[3]);
     device_shape.push_back(C0);
     return device_shape;
   } else if (format == kOpFormat_FRAC_Z) {
-    size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-    size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-    device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
+    size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+    size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+    device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize);
     device_shape.push_back(cout16 / kCubeSize);
     device_shape.push_back(kCubeSize);
     device_shape.push_back(kCubeSize);
     return device_shape;
   } else if (format == kOpFormat_NHWC) {
-    device_shape.push_back(shape[0]);
-    device_shape.push_back(shape[2]);
-    device_shape.push_back(shape[3]);
-    device_shape.push_back(shape[1]);
+    device_shape.push_back(temp_shape[0]);
+    device_shape.push_back(temp_shape[2]);
+    device_shape.push_back(temp_shape[3]);
+    device_shape.push_back(temp_shape[1]);
     return device_shape;
-  } else if (format == kOpFormat_NCHW) {
-    return shape;
   } else if (format == kOpFormat_HWCN) {
-    return {shape[2], shape[3], shape[1], shape[0]};
+    return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]};
+  } else if (format == kOpFormat_NCHW) {
+    return temp_shape;
   }
   MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
 }
@@ -24,6 +24,7 @@
 #include <utility>
 #include <vector>
 #include "ir/dtype.h"
+#include "kernel/kernel.h"
 #include "ir/dtype/type.h"
 
 namespace mindspore {
@@ -49,7 +50,10 @@ size_t TypeIdSize(const TypeId data_type);
 size_t ShapeSize(const std::vector<size_t> &shape);
 size_t CubeSizeByType(const TypeId data_type);
 
-std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape);
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape,
+                                     const std::vector<kernel::Axis> &padding_axis = {});
+std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
+bool IsNeedPadding(const std::string &format, const size_t shape_size);
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
 bool TransDataType(const TypeIdArgs &args, void *result);
 bool TransFormat(const FormatArgs &args, void *result);
@@ -141,7 +141,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
   if (format_ == kOpFormat_FRAC_NZ) {
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   } else {
-    host_shape = trans::TransShapeTo4d(host_shape);
+    host_shape = trans::PaddingShapeTo4d(host_shape);
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
@@ -224,7 +224,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
   if (format_ == kOpFormat_FRAC_NZ) {
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   } else {
-    host_shape = trans::TransShapeTo4d(host_shape);
+    host_shape = trans::PaddingShapeTo4d(host_shape);
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
@@ -27,6 +27,7 @@
 #include "utils/context/ms_context.h"
 #include "device/ascend/profiling/profiling_manager.h"
 #include "hccl/hcom.h"
+#include "common/trans.h"
 #include "runtime/context.h"
 #include "device/ascend/ascend_stream_assign.h"
 #include "device/ascend/ascend_memory_pool.h"
@@ -150,7 +151,7 @@ void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path,
     auto output_size = AnfAlgo::GetOutputTensorNum(node);
     for (size_t j = 0; j < output_size; ++j) {
       auto addr = AnfAlgo::GetOutputAddr(node, j);
-      auto shape = AnfAlgo::GetOutputInferShape(node, j);
+      auto shape = trans::GetRuntimePaddingShape(node, j);
       auto type = AnfAlgo::GetOutputInferDataType(node, j);
       auto format = kOpFormat_DEFAULT;
       string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j);
@@ -181,7 +182,7 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p
       continue;
     }
     auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX);
-    auto shape = AnfAlgo::GetOutputInferShape(item, PRAMATER_OUTPUT_INDEX);
+    auto shape = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX);
     auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX);
     auto format = kOpFormat_DEFAULT;
     string filepath = dump_path + '/' + parameter_name + '_' + "output_0";
@@ -184,7 +184,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
     }
     if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
       if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) &&
-          kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) {
+          kNeedTransFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kNeedTransFormatSet.end()) {
         (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++;
       }
       (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
@@ -210,19 +210,22 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
       (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
     }
   }
-} // namespace
+}
 
 void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
     auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
     MS_EXCEPTION_IF_NULL(input_kernel_node);
-    if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index)) {
-      continue;
-    }
     auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
     MS_EXCEPTION_IF_NULL(input_with_index.first);
     auto real_input_node = input_with_index.first;
+    if (real_input_node->isa<CNode>()) {
+      continue;
+    }
+    if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
+      continue;
+    }
     std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
       std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
     // we set special device info of an input tensor.
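
The rewritten loop only specializes device info for "weight boundary" inputs: after AnfAlgo::VisitKernel resolves the real producer, anything coming from another kernel (a CNode) or from a non-weight parameter is left alone. A hedged sketch of the rule with stubbed node kinds, since the real AnfNode hierarchy is too large to inline here:

```cpp
#include <iostream>

// Stub for the node kinds the skip rule distinguishes (illustrative only).
enum class NodeKind { kCNode, kFeatureMapParameter, kWeightParameter, kValueNode };

bool SkipDeviceInfoUpdate(NodeKind kind) {
  if (kind == NodeKind::kCNode) return true;                 // another kernel's build info owns the format
  if (kind == NodeKind::kFeatureMapParameter) return true;   // keep the graph-input format
  return false;  // weight parameters and value nodes are specialized in place
}

int main() {
  std::cout << SkipDeviceInfoUpdate(NodeKind::kCNode)
            << SkipDeviceInfoUpdate(NodeKind::kWeightParameter) << '\n';  // prints: 10
}
```
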
@@ -25,6 +25,7 @@
 #include "session/anf_runtime_algorithm.h"
 #include "utils/context/ms_context.h"
+#include "common/trans.h"
 #include "utils/config_manager.h"
 #include "common/utils.h"
 #include "kernel/kernel_build_info.h"
@@ -391,7 +392,8 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &c
     auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
     MS_EXCEPTION_IF_NULL(device_address);
     tensor->set_device_address(device_address);
-    if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
+    if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                          LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                           tensor->data_c(false))) {
       MS_LOG(INFO) << "SyncHostToDevice failed.";
       return false;
@@ -31,6 +31,7 @@ class KernelInfo {
  public:
  KernelInfo() {
    kernel_mod_ = nullptr;
+    is_feature_map_ = false;
    select_kernel_build_info_ = nullptr;
    output_address_list_ = {};
    workspace_address_list_ = {};
@@ -45,6 +46,7 @@ class KernelInfo {
   void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
     select_kernel_build_info_ = select_kernel_build_info;
   }
+  void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; }
   const DeviceAddress *GetOutputAddr(size_t index) const;
   DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
   bool OutputAddrExist(size_t index) const;
@@ -63,8 +65,10 @@ class KernelInfo {
   void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
   uint32_t graph_id() const { return graph_id_; }
   bool operator==(const KernelInfo &other) const;
+  bool is_feature_map() const { return is_feature_map_; }
 
  private:
+  bool is_feature_map_;
   kernel::KernelBuildInfoPtr select_kernel_build_info_;
   std::vector<std::shared_ptr<DeviceAddress>> output_address_list_;
   std::vector<std::shared_ptr<DeviceAddress>> workspace_address_list_;
@@ -105,7 +105,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod
   std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
   auto format = AnfAlgo::GetOutputFormat(node, output_index);
   if (shape.empty() && format != kOpFormat_DEFAULT) {
-    shape = trans::TransShapeTo4d(shape);
+    shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
     shape = trans::TransShapeToDevice(shape, format);
   }
   // scalar's output shape is an empty vector
@@ -401,8 +401,9 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
   auto address = CreateDeviceAddress(ptr, node_size, AnfAlgo::GetOutputFormat(value_node, output_idx), output_type_id);
   MS_EXCEPTION_IF_NULL(address);
   AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
-  if (!address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), tensor->data_c(false))) {
-    MS_EXCEPTION(NotExistsError) << "kValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
+  if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
+                                 tensor->data_c(false))) {
+    MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
                                  << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
                                  << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
   }
@@ -421,19 +422,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
     MS_EXCEPTION_IF_NULL(node_value);
     if (node_value->isa<Tensor>()) {
       AssignValueNodeTensor(value_node, node_value, 0);
-    } else if (node_value->isa<ValueTuple>()) {
-      auto value_tuple = node_value->cast<ValueTuplePtr>();
-      if (value_tuple == nullptr) {
-        MS_LOG(WARNING) << "value_tuple is null";
-        continue;
-      }
-      size_t i = 0;
-      auto value_list = value_tuple->value();
-      for (auto value_ptr : value_list) {
-        if (value_ptr->isa<Tensor>()) {
-          AssignValueNodeTensor(value_node, value_ptr, i++);
-        }
-      }
     } else if (node_value->isa<StringImm>()) {
       auto value = GetValue<std::string>(node_value);
       size_t tensor_size = value.size();
@@ -59,30 +59,20 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
 
 size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
 
-bool KernelBuildInfo::GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const {
-  MS_EXCEPTION_IF_NULL(reshape_type);
-  reshape_type->clear();
+std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
   if (input_index >= input_reshape_type_.size()) {
-    MS_LOG(WARNING) << "The index [" << input_index << "] is exceed the number of input node size "
-                    << input_reshape_type_.size();
-    return false;
+    MS_LOG(EXCEPTION) << "The index [" << input_index << "] exceeds the number of inputs "
+                      << input_reshape_type_.size();
   }
-  (void)std::copy(input_reshape_type_[input_index].begin(), input_reshape_type_[input_index].end(),
-                  std::inserter(*reshape_type, (*reshape_type).begin()));
-  return true;
+  return input_reshape_type_[input_index];
 }
 
-bool KernelBuildInfo::GetOutputReshapeType(size_t output_index, std::vector<Axis> *reshape_type) const {
-  MS_EXCEPTION_IF_NULL(reshape_type);
-  reshape_type->clear();
+std::vector<Axis> KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
   if (output_index >= output_reshape_type_.size()) {
-    MS_LOG(WARNING) << "The index [" << output_index << "] is exceed the number of output node dixr"
-                    << output_reshape_type_.size();
-    return false;
+    MS_LOG(EXCEPTION) << "The index [" << output_index << "] exceeds the number of outputs "
+                      << output_reshape_type_.size();
   }
-  (void)std::copy(output_reshape_type_[output_index].begin(), output_reshape_type_[output_index].end(),
-                  std::inserter(*reshape_type, (*reshape_type).begin()));
-  return true;
+  return output_reshape_type_[output_index];
 }
 
 std::string KernelBuildInfo::ToString() const {
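
The API change above replaces a bool-plus-out-parameter accessor (which only warned on a bad index) with one that returns the vector directly and raises on a bad index. A hedged sketch of the call-site migration, using a stub class rather than the real KernelBuildInfo so it stays self-contained:

```cpp
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

enum Axis : size_t { N, C, H, W };  // stand-in for kernel::Axis

class StubBuildInfo {
 public:
  explicit StubBuildInfo(std::vector<std::vector<Axis>> t) : input_reshape_type_(std::move(t)) {}
  // New style: invalid indices throw instead of returning false, so callers
  // no longer pre-declare a vector and branch on a bool.
  std::vector<Axis> GetInputReshapeType(size_t input_index) const {
    if (input_index >= input_reshape_type_.size()) {
      throw std::out_of_range("input_index exceeds the number of inputs");
    }
    return input_reshape_type_[input_index];
  }

 private:
  std::vector<std::vector<Axis>> input_reshape_type_;
};

int main() {
  StubBuildInfo info({{C, N}});
  auto reshape_type = info.GetInputReshapeType(0);  // was: bool ok = GetInputReshapeType(0, &reshape_type);
  std::cout << reshape_type.size() << '\n';         // prints: 2
}
```
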
@@ -115,6 +105,10 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
   return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
 }
 
+bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
+
+bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
+
 void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
   MS_EXCEPTION_IF_NULL(kernel_build_info_);
   kernel_build_info_->kernel_type_ = kernel_type;
@@ -54,9 +54,13 @@ class KernelBuildInfo {
   TypeId GetOutputDeviceType(size_t output_index) const;
 
-  bool GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const;
+  std::vector<Axis> GetInputReshapeType(size_t input_index) const;
 
-  bool GetOutputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const;
+  bool IsInputDefaultPadding() const;
+
+  bool IsOutputDefaultPadding() const;
+
+  std::vector<Axis> GetOutputReshapeType(size_t input_index) const;
 
   std::vector<std::string> GetAllInputFormats() const;
@@ -18,20 +18,21 @@
 #include <set>
 #include "common/trans.h"
 #include "common/utils.h"
+#include "utils/utils.h"
 #include "device/kernel_info.h"
 #include "kernel/oplib/oplib.h"
 #include "operator/ops.h"
 #include "session/anf_runtime_algorithm.h"
 #include "session/kernel_graph.h"
 #include "utils/context/ms_context.h"
-#include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
 namespace {
-kernel::KernelBuildInfoPtr CreateKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-                                                 const AnfNodePtr &node, const kernel::KernelBuildInfo ori_build_info) {
+kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
+                                                  const AnfNodePtr &node,
+                                                  const kernel::KernelBuildInfo ori_build_info) {
   KernelBuildInfoBuilder builder;
   builder.SetInputsFormat({input_format});
   builder.SetOutputsFormat({output_format});
@@ -54,9 +55,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
   CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
   MS_EXCEPTION_IF_NULL(trans_node);
   if (need_padding) {
-    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
-                                        {trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0))},
-                                        trans_node.get());
+    // if need padding we should set the transdata node's shape to the padding shape
+    AnfAlgo::SetOutputInferTypeAndShape(
+      {AnfAlgo::GetOutputInferDataType(input, 0)},
+      {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
+      trans_node.get());
   } else {
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                         {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
@@ -92,9 +95,11 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
 AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                 const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(node);
-  bool padding_flag = false;
   auto input_node = AnfAlgo::GetInputNode(node, index);
-  if (input_node->isa<ValueNode>() || input_node->isa<Parameter>()) {
+  auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
+  MS_EXCEPTION_IF_NULL(node_with_index.first);
+  auto real_input = node_with_index.first;
+  if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
     input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
     MS_EXCEPTION_IF_NULL(input_node);
     AnfAlgo::SetNodeInput(node, input_node, index);
@@ -106,33 +111,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
   std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
   std::string origin_format = kOpFormat_DEFAULT;
   std::string dest_format = AnfAlgo::GetInputFormat(node, index);
-  if (dest_format == kOpFormat_C1HWNCoC0) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    return replace_input;
-  }
-  if (dest_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    MS_LOG(DEBUG) << "Inserted Translate45, index: " << index;
-    return replace_input;
-  } else if (dest_format == kOpFormat_FRAC_NZ) {
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    MS_LOG(DEBUG) << "inserted translate " << AnfAlgo::GetInputFormat(node, index) << " To default, index: " << index;
-    return replace_input;
-  } else if (dest_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag,
-                                                     origin_format, dest_format, kTransDataOpName, true);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    MS_LOG(DEBUG) << "Inserted Translate45, index: " << index;
-    return replace_input;
+  if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
+    MS_LOG(DEBUG) << node->DebugString() << " insert transdata " << AnfAlgo::GetInputFormat(node, index)
+                  << " to default format, index: " << index;
+    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName,
+                                 true);
   }
   return input_node;
 }
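
The four near-identical per-format branches collapse into one membership test against kNeedTransFormatSet. A hedged sketch of the consolidated check; the set's contents are an assumption inferred from the formats the removed branches handled, and the string values are illustrative rather than the real kOpFormat_* constants:

```cpp
#include <cstddef>
#include <iostream>
#include <set>
#include <string>

// Formats assumed to require a TransData node on the way to/from default format.
const std::set<std::string> kNeedTransFormatSet = {"FRACTAL_NZ", "NC1HWC0", "FracZ", "C1HWNCoC0"};

bool NeedInsertTransData(const std::string &dest_format, size_t origin_shape_size) {
  // One lookup replaces four if/else-if branches; scalars and 1-D shapes skip.
  return kNeedTransFormatSet.count(dest_format) != 0 && origin_shape_size > 1;
}

int main() {
  std::cout << NeedInsertTransData("NC1HWC0", 4) << ' '        // 1: insert TransData
            << NeedInsertTransData("DefaultFormat", 4) << ' '  // 0: already default
            << NeedInsertTransData("NC1HWC0", 1) << '\n';      // 0: 1-D input skips
}
```
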
@@ -140,7 +123,6 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
 AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(node);
-  bool padding_flag = false;
   std::string output_format;
   std::vector<size_t> origin_shape;
   if (!AnfAlgo::IsRealKernel(node)) {
@@ -156,46 +138,14 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
   }
   std::string origin_format = output_format;
   std::string dest_format = kOpFormat_DEFAULT;
-  if (output_format == kOpFormat_C1HWNCoC0) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                     dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_input);
-    return replace_input;
-  }
-  if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                      dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_output);
-    MS_LOG(DEBUG) << "Inserted Trans54";
-    return replace_output;
-  } else if (output_format == kOpFormat_FRAC_NZ) {
-    AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                      dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_output);
-    MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: 0";
-    return replace_output;
-  } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
-    padding_flag = (origin_shape.size() != kShape4dDims);
-    AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format,
-                                                      dest_format, kTransDataOpName, false);
-    MS_EXCEPTION_IF_NULL(replace_output);
-    MS_LOG(DEBUG) << "Inserted Trans54";
-    return replace_output;
+  if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
+    MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default format, index: 0";
+    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName,
+                                 false);
   }
   return node;
 }
 
-void GetTransDataInputFormat(const AnfNodePtr &node, size_t idx, std::string *input_format) {
-  MS_EXCEPTION_IF_NULL(input_format);
-  if (AnfAlgo::IsRealKernel(node)) {
-    *input_format = AnfAlgo::GetOutputFormat(node, idx);
-  } else {
-    *input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
-  }
-}
-
 AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -203,46 +153,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
   std::vector<AnfNodePtr> make_tuple_inputs;
   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
   for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
-    bool padding_flag = false;
-    std::string output_format;
-    GetTransDataInputFormat(node, output_idx, &output_format);
+    std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
     if (output_format == kOpFormat_NC1KHKWHWC0) {
-      MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node "
+      MS_LOG(EXCEPTION) << "Got the special format " << output_format << " when inserting the transdata node "
                         << node->DebugString();
     }
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
-    std::string origin_format = output_format;
     std::string dest_format = kOpFormat_DEFAULT;
-    if (output_format == kOpFormat_C1HWNCoC0) {
-      padding_flag = (origin_shape.size() != kShape4dDims);
-      AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                       origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_input);
-      return replace_input;
-    }
-    if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) {
-      padding_flag = (origin_shape.size() != kShape4dDims);
-      // Insert a 5to4 trans op.
-      AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                        origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_output);
-      MS_LOG(DEBUG) << "Inserted Translate54";
-      make_tuple_inputs.push_back(replace_output);
-    } else if (output_format == kOpFormat_FRAC_NZ) {
-      AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                        origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_output);
-      MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: " << output_idx;
-      make_tuple_inputs.push_back(replace_output);
-    } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) {
-      padding_flag = (origin_shape.size() != kShape4dDims);
-      AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag,
-                                                        origin_format, dest_format, kTransDataOpName, false);
-      MS_EXCEPTION_IF_NULL(replace_output);
-      MS_LOG(DEBUG) << "Inserted Translate54";
-      make_tuple_inputs.push_back(replace_output);
+    if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
+      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format,
+                                                           dest_format, kTransDataOpName, false));
     } else {
       // No need insert trans op.
       make_tuple_inputs.push_back(tuple_getitem);
@@ -253,16 +174,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
 }
 }  // namespace
 
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                 const KernelSelectPtr &kernel_select, size_t insert_index, const bool padding_flag,
+                                 const KernelSelectPtr &kernel_select, size_t insert_index,
                                  const std::string &origin_format, const std::string &dest_format,
                                  const std::string &op_name, bool is_insert_input) {
   AnfNodePtr trans_node = nullptr;
-  AnfNodePtr input_node = nullptr;
+  AnfNodePtr input_node = node;
   AnfNodePtr trans_data = nullptr;
   MS_EXCEPTION_IF_NULL(node);
   if (origin_format.empty() || dest_format.empty()) {
     MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest = " << dest_format;
   }
+  // if insert transdata for input we need to change the input
   if (is_insert_input) {
     if (!node->isa<CNode>()) {
       MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
@@ -270,29 +192,34 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
-    if (padding_flag) {
-      auto padd_shape = trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0));
-      auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padd_shape);
-      trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, padding_flag, op_name);
-    } else {
-      trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
-    }
+  }
+  bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
+                       op_name == kTransDataOpName);
+  if (!need_padding) {
+    // don't need padding, insert transdata only
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
+    trans_node = trans_data;
+  } else if (is_insert_input) {
+    // if need padding & is input, need insert a transdata
+    // reshape[padding shape] -> transdata[padding shape] -> node
+    auto padding_shape =
+      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
+    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
+    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name);
     trans_node = trans_data;
   } else {
-    input_node = node;
-    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
-    if (padding_flag) {
-      auto reshape_node =
-        CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
-      trans_node = reshape_node;
-    } else {
-      trans_node = trans_data;
-    }
+    // if need padding & is output, need insert a transdata
+    // node -> transdata[padding shape] -> reshape[ori_shape]
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
+    auto reshape_node =
+      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
+    trans_node = reshape_node;
   }
+  // refresh the transdata's format to ori format & dst format
   MS_EXCEPTION_IF_NULL(trans_data);
   MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
   auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
-  auto kernel_build_info = CreateKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
+  auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
   return trans_node;
 }
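
The padding comments in the rewritten function describe two mirrored insertion orders: on the input side the reshape to the padded shape comes before the TransData, while on the output side the TransData comes first and a reshape restores the original host shape. A toy sketch modeling the node chains as strings, just to make the ordering visible:

```cpp
#include <iostream>
#include <string>

// Input side: reshape[padded shape] -> transdata, feeding the consumer node.
std::string InsertForInput(const std::string &input) {
  return "TransData(Reshape(" + input + "))";
}

// Output side: transdata -> reshape[original shape], restoring the host view.
std::string InsertForOutput(const std::string &node) {
  return "Reshape(TransData(" + node + "))";
}

int main() {
  std::cout << InsertForInput("x") << '\n';   // TransData(Reshape(x))
  std::cout << InsertForOutput("y") << '\n';  // Reshape(TransData(y))
}
```
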
@@ -376,7 +303,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
   for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
     TypeId origin_type;
     auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
-    if (!AnfAlgo::IsFeatureMapInput(cnode, input_index)) {
+    auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
+    auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
+      if (node->isa<ValueNode>()) {
+        return true;
+      } else if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
+        return true;
+      }
+      return false;
+    };
+    auto real_input_node = kernel_with_index.first;
+    if (is_weight_boundary(real_input_node)) {
       // weight
       origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
     } else {
@@ -48,7 +48,7 @@ class KernelQuery {
 using KernelQueryPtr = std::shared_ptr<KernelQuery>;
 
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool padding_flag,
+                                 const KernelSelectPtr &kernel_select, size_t insert_index,
                                  const std::string &origin_format, const std::string &dest_format,
                                  const std::string &op_name, bool is_insert_input);
@@ -105,10 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
   // insert trans
   if (origin_format != cur_format) {
     auto kernel_select = std::make_shared<KernelSelect>();
-    bool need_padding =
-      (cur_format == kOpFormat_NC1HWC0 && AnfAlgo::GetOutputInferShape(final_node, 0).size() != kShape4dDims);
-    final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, need_padding, cur_format,
-                                       origin_format, kTransDataOpName, false);
+    final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format,
+                                       kTransDataOpName, false);
     final_index = 0;
     MS_EXCEPTION_IF_NULL(final_node);
     MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
@@ -1,99 +1,99 @@
 /**
  * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #include "pre_activate/ascend/ir_fusion/transdata_split.h"
 #include <set>
 #include "pre_activate/ascend/ascend_helper.h"
 #include "session/anf_runtime_algorithm.h"
 #include "debug/anf_ir_dump.h"
 
 namespace mindspore {
 namespace opt {
 const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
                                                                   {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
                                                                   {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
                                                                   {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
 
 bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
   bool changed = false;
   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
   for (auto &node : node_list) {
     if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
       CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
       if (IsFormatInvaild(node)) {
         changed = DoSplit(func_graph, node);
       }
     }
   }
   return changed;
 }
 
 bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   auto input_format = AnfAlgo::GetInputFormat(node, 0);
   auto output_format = AnfAlgo::GetOutputFormat(node, 0);
   auto format_pair = std::make_pair(input_format, output_format);
   return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
 }
 
 // TransData cannot convert FracZ to NCHW directly, so split it into TransData (FracZ -> HWCN)
 // and Transpose (HWCN -> NCHW)
 bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   auto input_node = node->cast<CNodePtr>()->input(1);
   MS_EXCEPTION_IF_NULL(input_node);
 
   auto input_format = AnfAlgo::GetInputFormat(node, 0);
   auto output_format = AnfAlgo::GetOutputFormat(node, 0);
   AnfNodePtr new_transdata_node = nullptr;
   AnfNodePtr new_transpose_node = nullptr;
   AnfNodePtr new_replace_node = nullptr;
   // if output_format=default, transdata need split to transdata->transpose, else transpose->transdata
   if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
     // trans input_format to hwcn
-    new_transdata_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN,
-                                               kTransDataOpName, true);
+    new_transdata_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN,
+                                               kTransDataOpName, true);
     // trans hwcn to default_format
-    new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, false, kOpFormat_HWCN,
-                                               output_format, prim::kPrimTranspose->name(), false);
+    new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN,
+                                               output_format, prim::kPrimTranspose->name(), false);
     AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
     new_replace_node = new_transpose_node;
   } else {
     // trans default to hwcn
-    new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN,
-                                               prim::kPrimTranspose->name(), true);
+    new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN,
+                                               prim::kPrimTranspose->name(), true);
     AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
     // trans hwcn to output_format
-    new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, false, kOpFormat_HWCN,
-                                               output_format, kTransDataOpName, false);
+    new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN,
+                                               output_format, kTransDataOpName, false);
     new_replace_node = new_transdata_node;
   }
   FuncGraphManagerPtr manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   manager->AddFuncGraph(func_graph);
   if (!manager->Replace(node, new_replace_node)) {
     MS_LOG(EXCEPTION) << "Manager replace node failed";
   }
   MS_LOG(INFO) << "Transdata node: " << cnode->DebugString() << " split success.";
   return true;
 }
 }  // namespace opt
 }  // namespace mindspore
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/ir_fusion/transdata_split.h" | |||||
| #include <set> | |||||
| #include "pre_activate/ascend/ascend_helper.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, | |||||
| {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, | |||||
| {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, | |||||
| {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; | |||||
| bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| bool changed = false; | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||||
| for (auto &node : node_list) { | |||||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { | |||||
| CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum); | |||||
| if (IsFormatInvaild(node)) { | |||||
| changed = DoSplit(func_graph, node); | |||||
| } | |||||
| } | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto input_format = AnfAlgo::GetInputFormat(node, 0); | |||||
| auto output_format = AnfAlgo::GetOutputFormat(node, 0); | |||||
| auto format_pair = std::make_pair(input_format, output_format); | |||||
| return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); | |||||
| } | |||||
| // transdata cannot support frac_z to nchw directly; split it into transdata (frac_z -> HWCN) and transpose (HWCN -> NCHW) | |||||
| bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto input_node = node->cast<CNodePtr>()->input(1); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| auto input_format = AnfAlgo::GetInputFormat(node, 0); | |||||
| auto output_format = AnfAlgo::GetOutputFormat(node, 0); | |||||
| AnfNodePtr new_transdata_node = nullptr; | |||||
| AnfNodePtr new_transpose_node = nullptr; | |||||
| AnfNodePtr new_replace_node = nullptr; | |||||
| // if output_format is the default format, split into transdata -> transpose; otherwise transpose -> transdata | |||||
| if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { | |||||
| // trans input_format to hwcn | |||||
| new_transdata_node = | |||||
| AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true); | |||||
| // trans hwcn to default_format | |||||
| new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN, | |||||
| output_format, prim::kPrimTranspose->name(), false); | |||||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node); | |||||
| new_replace_node = new_transpose_node; | |||||
| } else { | |||||
| // trans default to hwcn | |||||
| new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, | |||||
| prim::kPrimTranspose->name(), true); | |||||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node); | |||||
| // trans hwcn to output_format | |||||
| new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN, | |||||
| output_format, kTransDataOpName, false); | |||||
| new_replace_node = new_transdata_node; | |||||
| } | |||||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| manager->AddFuncGraph(func_graph); | |||||
| if (!manager->Replace(node, new_replace_node)) { | |||||
| MS_LOG(EXCEPTION) << "Manager replace node failed"; | |||||
| } | |||||
| MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success."; | |||||
| return true; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
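The split above hinges on two fixed transpose permutations. Below is a minimal standalone sketch (not part of the patch) that checks them, assuming the usual convention that perm[i] selects the source axis for output axis i: {3, 2, 0, 1} reorders HWCN into NCHW, and {2, 3, 1, 0} reorders NCHW into HWCN.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Reorder a string of axis labels according to a transpose permutation,
// where perm[i] names the source axis that becomes output axis i.
std::string ApplyPerm(const std::string &axes, const std::vector<int> &perm) {
  std::string out;
  for (int p : perm) {
    out += axes[static_cast<size_t>(p)];
  }
  return out;
}

int main() {
  // Split path 1: TransData(input_format -> HWCN), then Transpose with
  // perm {3, 2, 0, 1} turns HWCN into NCHW (the default format).
  assert(ApplyPerm("HWCN", {3, 2, 0, 1}) == "NCHW");
  // Split path 2: Transpose with perm {2, 3, 1, 0} turns NCHW into HWCN,
  // then TransData(HWCN -> output_format).
  assert(ApplyPerm("NCHW", {2, 3, 1, 0}) == "HWCN");
  return 0;
}
```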
| @@ -291,6 +291,11 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { | |||||
| std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { | std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (output_idx >= GetOutputTensorNum(node)) { | |||||
| MS_LOG(EXCEPTION) << "Output index: " << output_idx | |||||
| << " is out of the node's output range: " << GetOutputTensorNum(node) << " #node [" | |||||
| << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | auto kernel_info = node->kernel_info(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| auto build_info = kernel_info->select_kernel_build_info(); | auto build_info = kernel_info->select_kernel_build_info(); | ||||
| @@ -300,6 +305,11 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t | |||||
| std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) { | std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (input_idx >= GetInputTensorNum(node)) { | |||||
| MS_LOG(EXCEPTION) << "Input index: " << input_idx | |||||
| << " is out of the node's input range: " << GetInputTensorNum(node) << " #node [" | |||||
| << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | auto kernel_info = node->kernel_info(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| auto build_info = kernel_info->select_kernel_build_info(); | auto build_info = kernel_info->select_kernel_build_info(); | ||||
| @@ -364,62 +374,60 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNo | |||||
| std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) { | std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) { | ||||
| auto format = GetOutputFormat(node, output_idx); | auto format = GetOutputFormat(node, output_idx); | ||||
| auto infer_shape = GetOutputInferShape(node, output_idx); | auto infer_shape = GetOutputInferShape(node, output_idx); | ||||
| // if format is default_format or NC1KHKWHWC0,device shape = original shape | |||||
| if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) { | |||||
| return infer_shape; | |||||
| } | |||||
| // scalar shape | |||||
| if (infer_shape.empty()) { | if (infer_shape.empty()) { | ||||
| return infer_shape; | return infer_shape; | ||||
| } | } | ||||
| if (format == kOpFormat_FRAC_NZ) { | |||||
| return trans::TransShapeToDevice(infer_shape, format); | |||||
| // pad the inferred shape to 4D when the device format requires padding | |||||
| if (trans::IsNeedPadding(format, infer_shape.size())) { | |||||
| infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx)); | |||||
| } | } | ||||
| // else trans infer shape to 4d and then calculate device shape | |||||
| return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format); | |||||
| return trans::TransShapeToDevice(infer_shape, format); | |||||
| } | } | ||||
| std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { | std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { | ||||
| auto format = GetInputFormat(node, input_idx); | auto format = GetInputFormat(node, input_idx); | ||||
| auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx); | auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx); | ||||
| // if format is default_format or NC1KHKWHWC0,device shape = original shape | |||||
| if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) { | |||||
| return infer_shape; | |||||
| } | |||||
| if (infer_shape.empty()) { | if (infer_shape.empty()) { | ||||
| return infer_shape; | return infer_shape; | ||||
| } | } | ||||
| if (format == kOpFormat_FRAC_NZ) { | |||||
| return trans::TransShapeToDevice(infer_shape, format); | |||||
| // pad the inferred shape to 4D when the device format requires padding | |||||
| if (trans::IsNeedPadding(format, infer_shape.size())) { | |||||
| infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); | |||||
| } | } | ||||
| // else trans infer shape to 4d and then calculate device shape | |||||
| return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format); | |||||
| return trans::TransShapeToDevice(infer_shape, format); | |||||
| } | } | ||||
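Both device-shape getters now converge on the same control flow. Here is a sketch of that flow with stubbed helpers; NeedPadding, PadTo4d and ToDevice stand in for the trans:: functions, and the prepend-1s padding rule is an assumption made for this demo only, not the library's actual rule.

```cpp
#include <cstddef>
#include <string>
#include <vector>

bool NeedPadding(const std::string &format, std::size_t rank) {
  // Padded formats expect a 4D shape; default and FRAC_NZ layouts do not.
  return format != "DefaultFormat" && format != "FRACTAL_NZ" && rank < 4;
}

std::vector<std::size_t> PadTo4d(std::vector<std::size_t> shape) {
  while (shape.size() < 4) shape.insert(shape.begin(), 1);  // assumed rule
  return shape;
}

std::vector<std::size_t> ToDevice(const std::vector<std::size_t> &shape, const std::string &) {
  return shape;  // identity stand-in for trans::TransShapeToDevice
}

std::vector<std::size_t> DeviceShapeSketch(std::vector<std::size_t> infer_shape,
                                           const std::string &format) {
  if (infer_shape.empty()) {
    return infer_shape;  // scalar shape passes through unchanged
  }
  if (NeedPadding(format, infer_shape.size())) {
    infer_shape = PadTo4d(infer_shape);  // expand to 4D before device translation
  }
  return ToDevice(infer_shape, format);
}
```

The design point of the refactor is that the per-format early returns are gone: padding is decided by one predicate and the format-specific work lives entirely inside the shape translation step.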
| std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { | std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (input_idx >= GetInputTensorNum(node)) { | |||||
| MS_LOG(EXCEPTION) << "The index: " << input_idx | |||||
| << " is out of range of the node's input size: " << GetInputTensorNum(node) << " #node [" | |||||
| << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | auto kernel_info = node->kernel_info(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| auto build_info = kernel_info->select_kernel_build_info(); | auto build_info = kernel_info->select_kernel_build_info(); | ||||
| MS_EXCEPTION_IF_NULL(build_info); | MS_EXCEPTION_IF_NULL(build_info); | ||||
| std::vector<kernel::Axis> result; | |||||
| if (!build_info->GetInputReshapeType(input_idx, &result)) { | |||||
| MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !"; | |||||
| if (build_info->IsInputDefaultPadding()) { | |||||
| return {}; | |||||
| } | } | ||||
| return result; | |||||
| return build_info->GetInputReshapeType(input_idx); | |||||
| } | } | ||||
| std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { | std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (output_idx >= GetOutputTensorNum(node)) { | |||||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [" | |||||
| << GetOutputTensorNum(node) << "] #node [" << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | auto kernel_info = node->kernel_info(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| auto build_info = kernel_info->select_kernel_build_info(); | auto build_info = kernel_info->select_kernel_build_info(); | ||||
| MS_EXCEPTION_IF_NULL(build_info); | MS_EXCEPTION_IF_NULL(build_info); | ||||
| std::vector<kernel::Axis> result; | |||||
| if (!build_info->GetOutputReshapeType(output_idx, &result)) { | |||||
| MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !"; | |||||
| if (build_info->IsOutputDefaultPadding()) { | |||||
| return {}; | |||||
| } | } | ||||
| return result; | |||||
| return build_info->GetOutputReshapeType(output_idx); | |||||
| } | } | ||||
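The reshape-type getters move from an out-parameter-plus-bool interface to a default-padding predicate with a direct return. A hedged sketch of the new calling pattern against a mock build info follows; MockBuildInfo and Axis are stand-ins for the real kernel:: types, not their actual definitions.

```cpp
#include <cstddef>
#include <vector>

enum class Axis { N, C, H, W };  // stand-in for kernel::Axis

class MockBuildInfo {
 public:
  // Default padding means no explicit per-axis reshape type was recorded.
  bool IsInputDefaultPadding() const { return reshape_type_.empty(); }
  std::vector<Axis> GetInputReshapeType(std::size_t /*input_idx*/) const { return reshape_type_; }

 private:
  std::vector<Axis> reshape_type_;
};

// New-style accessor: an empty vector signals "use the default padding rule",
// replacing the old fill-an-out-parameter-or-throw contract.
std::vector<Axis> InputReshapeTypeSketch(const MockBuildInfo &info, std::size_t idx) {
  if (info.IsInputDefaultPadding()) {
    return {};
  }
  return info.GetInputReshapeType(idx);
}
```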
| TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { | TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { | ||||
| @@ -465,6 +473,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &nod | |||||
| TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) { | TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (output_idx >= GetOutputTensorNum(node)) { | |||||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [" | |||||
| << GetOutputTensorNum(node) << "] #node [" << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | auto kernel_info = node->kernel_info(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| auto build_info = kernel_info->select_kernel_build_info(); | auto build_info = kernel_info->select_kernel_build_info(); | ||||
| @@ -474,6 +486,10 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size | |||||
| TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) { | TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (input_idx >= GetInputTensorNum(node)) { | |||||
| MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [" | |||||
| << GetInputTensorNum(node) << "] #node [" << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | auto kernel_info = node->kernel_info(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| auto build_info = kernel_info->select_kernel_build_info(); | auto build_info = kernel_info->select_kernel_build_info(); | ||||
| @@ -498,11 +514,15 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, | |||||
| MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; | MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; | ||||
| } | } | ||||
| } | } | ||||
| if (output_idx >= GetOutputTensorNum(node)) { | |||||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [" | |||||
| << GetOutputTensorNum(node) << "] #node [" << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | auto kernel_info = node->kernel_info(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| auto addr = kernel_info->GetOutputAddr(output_idx); | auto addr = kernel_info->GetOutputAddr(output_idx); | ||||
| if (addr == nullptr) { | if (addr == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "output_idx " << output_idx << " of node " << node->DebugString() | |||||
| MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() | |||||
| << " output addr is not exist"; | << " output addr is not exist"; | ||||
| } | } | ||||
| return addr; | return addr; | ||||
| @@ -519,11 +539,15 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod | |||||
| MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; | MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; | ||||
| } | } | ||||
| } | } | ||||
| if (output_idx >= GetOutputTensorNum(node)) { | |||||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [" | |||||
| << GetOutputTensorNum(node) << "] #node [" << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | auto kernel_info = node->kernel_info(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| auto addr = kernel_info->GetMutableOutputAddr(output_idx); | auto addr = kernel_info->GetMutableOutputAddr(output_idx); | ||||
| if (addr == nullptr) { | if (addr == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "output_idx" << output_idx << " of node " << node->DebugString() | |||||
| MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString() | |||||
| << " output addr is not exist"; | << " output addr is not exist"; | ||||
| } | } | ||||
| return addr; | return addr; | ||||
| @@ -532,6 +556,10 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod | |||||
| // get output device addr of anf_node | // get output device addr of anf_node | ||||
| bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { | bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (output_idx >= GetOutputTensorNum(node)) { | |||||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [" | |||||
| << GetOutputTensorNum(node) << "] #node [" << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | auto kernel_info = node->kernel_info(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| return kernel_info->OutputAddrExist(output_idx); | return kernel_info->OutputAddrExist(output_idx); | ||||
| @@ -771,22 +799,24 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) | |||||
| return node->input(get_input_index); | return node->input(get_input_index); | ||||
| } | } | ||||
| bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<ValueNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto kernel_info = node->kernel_info(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| return kernel_info->is_feature_map(); | |||||
| } | |||||
| bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) { | bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) { | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature"; | |||||
| MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map"; | |||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto input_node = cnode->input(input_index + 1); | auto input_node = cnode->input(input_index + 1); | ||||
| auto node_with_index = VisitKernel(input_node, 0); | |||||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||||
| if (node_with_index.first->isa<ValueNode>()) { | |||||
| return false; | |||||
| } | |||||
| if (node_with_index.first->isa<Parameter>()) { | |||||
| return !AnfAlgo::IsParameterWeight(node_with_index.first->cast<ParameterPtr>()); | |||||
| } | |||||
| return true; | |||||
| return IsFeatureMapOutput(input_node); | |||||
| } | } | ||||
| size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) { | size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) { | ||||
| @@ -102,7 +102,9 @@ class AnfRuntimeAlgorithm { | |||||
| static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); | static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); | ||||
| // get input shapes which will built and run in device | // get input shapes which will built and run in device | ||||
| static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); | static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); | ||||
| // Get Input Padding Axis | |||||
| static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); | static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); | ||||
| // Get Output Padding Axis | |||||
| static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); | static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); | ||||
| // get output data type inferred by ME of anf node | // get output data type inferred by ME of anf node | ||||
| static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); | static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); | ||||
| @@ -166,6 +168,9 @@ class AnfRuntimeAlgorithm { | |||||
| // get graph id | // get graph id | ||||
| static uint32_t GetGraphId(const AnfNode *node); | static uint32_t GetGraphId(const AnfNode *node); | ||||
| static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); | static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); | ||||
| // check whether the node's output is a feature map | |||||
| static bool IsFeatureMapOutput(const AnfNodePtr &node); | |||||
| // check whether the node's input comes from a feature map output | |||||
| static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); | static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); | ||||
| // get real input index for some tbe ops which input order is different between me and tbe impl | // get real input index for some tbe ops which input order is different between me and tbe impl | ||||
| static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); | static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "ir/meta_tensor.h" | #include "ir/meta_tensor.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "common/trans.h" | |||||
| #include "device/kernel_runtime.h" | #include "device/kernel_runtime.h" | ||||
| #include "device/ascend/kernel_select_ascend.h" | #include "device/ascend/kernel_select_ascend.h" | ||||
| #include "device/ascend/kernel_build_ascend.h" | #include "device/ascend/kernel_build_ascend.h" | ||||
| @@ -730,8 +731,8 @@ void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor | |||||
| size_t tensor_size = front_tensor->data().nbytes(); | size_t tensor_size = front_tensor->data().nbytes(); | ||||
| auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0); | auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0); | ||||
| MS_EXCEPTION_IF_NULL(addr); | MS_EXCEPTION_IF_NULL(addr); | ||||
| if (!addr->SyncHostToDevice(front_tensor->shape(), tensor_size, front_tensor->data_type(), | |||||
| front_tensor->data_c(false))) { | |||||
| if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size, | |||||
| front_tensor->data_type(), front_tensor->data_c(false))) { | |||||
| MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!"; | MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!"; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| @@ -143,6 +143,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||||
| cnode->set_abstract(std::make_shared<abstract::AbstractNone>()); | cnode->set_abstract(std::make_shared<abstract::AbstractNone>()); | ||||
| // create kernel_info from new parameter | // create kernel_info from new parameter | ||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | auto kernel_info = std::make_shared<device::KernelInfo>(); | ||||
| // if the node has only the primitive input (such as GetNext) or any of its inputs is a feature map | |||||
| // output, then the node's output is a feature map output | |||||
| if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(), | |||||
| [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) { | |||||
| kernel_info->SetFeatureMapFlag(true); | |||||
| } | |||||
| cnode->set_kernel_info(kernel_info); | cnode->set_kernel_info(kernel_info); | ||||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | AnfAlgo::SetGraphId(graph_id_, cnode.get()); | ||||
| return cnode; | return cnode; | ||||
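The predicate this hunk adds to KernelGraph::NewCNode can be restated compactly. The sketch below uses a stand-in Node type; the real code queries kernel_info()->is_feature_map() via AnfAlgo::IsFeatureMapOutput.

```cpp
#include <algorithm>
#include <vector>

// Stand-in for an ANF node; the real code asks the node's kernel info.
struct Node {
  bool is_feature_map = false;
};

// inputs[0] is the primitive value node; real operands start at index 1.
// A CNode produces a feature map if it has no real operands (e.g. GetNext)
// or if any real operand itself produces a feature map.
bool OutputIsFeatureMap(const std::vector<const Node *> &inputs) {
  return inputs.size() == 1 ||
         std::any_of(inputs.begin() + 1, inputs.end(),
                     [](const Node *n) { return n->is_feature_map; });
}
```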
| @@ -162,22 +168,26 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { | |||||
| ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { | ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { | ||||
| ParameterPtr new_parameter = add_parameter(); | ParameterPtr new_parameter = add_parameter(); | ||||
| MS_EXCEPTION_IF_NULL(new_parameter); | MS_EXCEPTION_IF_NULL(new_parameter); | ||||
| // create kernel_info from the new parameter | |||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||||
| size_t output_tensor_num = 1; | size_t output_tensor_num = 1; | ||||
| // if the default parameter is nullptr, it means creating a new parameter from scratch | // if the default parameter is nullptr, it means creating a new parameter from scratch |
| if (parameter == nullptr) { | if (parameter == nullptr) { | ||||
| new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>()); | new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>()); | ||||
| kernel_info->SetFeatureMapFlag(true); | |||||
| } else { | } else { | ||||
| // if the default parameter is not nullptr, it means creating a new parameter from an old one | // if the default parameter is not nullptr, it means creating a new parameter from an old one |
| new_parameter->set_abstract(parameter->abstract()); | new_parameter->set_abstract(parameter->abstract()); | ||||
| new_parameter->set_name(parameter->name()); | new_parameter->set_name(parameter->name()); | ||||
| if (parameter->has_default()) { | |||||
| if (AnfAlgo::IsParameterWeight(parameter)) { | |||||
| new_parameter->set_default_param(parameter->default_param()); | new_parameter->set_default_param(parameter->default_param()); | ||||
| kernel_info->SetFeatureMapFlag(false); | |||||
| } else { | |||||
| kernel_info->SetFeatureMapFlag(true); | |||||
| } | } | ||||
| // if the output is a tuple tensor, a for loop can handle each tensor in the tuple | // if the output is a tuple tensor, a for loop can handle each tensor in the tuple |
| output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter); | output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter); | ||||
| } | } | ||||
| // create kernel_info form new parameter | |||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||||
| new_parameter->set_kernel_info(kernel_info); | new_parameter->set_kernel_info(kernel_info); | ||||
| // create kernel_build_info for new parameter | // create kernel_build_info for new parameter | ||||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | ||||
| @@ -217,6 +227,7 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNo | |||||
| AddValueNodeToGraph(new_value_node); | AddValueNodeToGraph(new_value_node); | ||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | auto kernel_info = std::make_shared<device::KernelInfo>(); | ||||
| new_value_node->set_kernel_info(kernel_info); | new_value_node->set_kernel_info(kernel_info); | ||||
| kernel_info->SetFeatureMapFlag(false); | |||||
| // create kernel_build_info for new value node | // create kernel_build_info for new value node | ||||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | ||||
| // set the format of value_node to DEFAULT_FORMAT | // set the format of value_node to DEFAULT_FORMAT | ||||
| @@ -240,6 +251,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { | |||||
| new_value_node->set_abstract(value_node->abstract()); | new_value_node->set_abstract(value_node->abstract()); | ||||
| // create kernel_info for new value node | // create kernel_info for new value node |
| auto kernel_info = std::make_shared<device::KernelInfo>(); | auto kernel_info = std::make_shared<device::KernelInfo>(); | ||||
| kernel_info->SetFeatureMapFlag(false); | |||||
| new_value_node->set_kernel_info(kernel_info); | new_value_node->set_kernel_info(kernel_info); | ||||
| // create kernel_build_info for new value node | // create kernel_build_info for new value node | ||||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "pipeline/parse/data_converter.h" | #include "pipeline/parse/data_converter.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "common/trans.h" | |||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| @@ -124,7 +125,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | MS_EXCEPTION_IF_NULL(ms_context); | ||||
| if (ms_context->enable_pynative_infer()) { | if (ms_context->enable_pynative_infer()) { | ||||
| tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); | tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); | ||||
| } else if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||||
| } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), | |||||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||||
| tensor->data_c(true))) { | tensor->data_c(true))) { | ||||
| MS_LOG(INFO) << "output sync device to host error!!!"; | MS_LOG(INFO) << "output sync device to host error!!!"; | ||||
| tensor->set_dirty(false); | tensor->set_dirty(false); | ||||
| @@ -369,7 +371,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, | |||||
| kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()}); | kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()}); | ||||
| } | } | ||||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); | AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); | ||||
| // construct abstract of parameter | // construct abstract of parameter |
| auto abstract = std::make_shared<abstract::AbstractTensor>(input_tensor); | auto abstract = std::make_shared<abstract::AbstractTensor>(input_tensor); | ||||
| param->set_abstract(abstract); | param->set_abstract(abstract); | ||||
| return param; | return param; | ||||
| @@ -548,7 +550,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||||
| if (need_sync) { | if (need_sync) { | ||||
| tensor->set_device_address(device_address); | tensor->set_device_address(device_address); | ||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||||
| if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||||
| tensor->data_c(false))) { | tensor->data_c(false))) { | ||||
| MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | ||||
| } | } | ||||
| @@ -620,8 +623,8 @@ void SessionBasic::Summary(KernelGraph *graph) { | |||||
| (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); | (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); | ||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); | tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); | ||||
| MS_EXCEPTION_IF_NULL(address); | MS_EXCEPTION_IF_NULL(address); | ||||
| if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||||
| tensor->data_c(true))) { | |||||
| if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()), | |||||
| tensor->data_type(), tensor->data_c(true))) { | |||||
| MS_LOG(ERROR) << "Failed to sync output from device to host."; | MS_LOG(ERROR) << "Failed to sync output from device to host."; | ||||
| } | } | ||||
| tensor->set_dirty(false); | tensor->set_dirty(false); | ||||
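All three sync call sites above now pass the runtime padding shape instead of the tensor's own shape, so the host-side metadata matches the padded layout the device buffer was allocated with. A small sanity check follows, assuming (for illustration only) that the default rule pads a rank-2 shape into the middle two axes of a 4D shape; the element count, and hence the copy size, is unchanged.

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Element count for a buffer laid out with the given shape.
std::size_t ElementCount(const std::vector<std::size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                         std::multiplies<std::size_t>());
}

int main() {
  std::vector<std::size_t> tensor_shape = {16, 32};
  std::vector<std::size_t> runtime_padding_shape = {1, 16, 32, 1};  // assumed padded form
  // Padding with 1s only changes shape metadata, never the number of elements,
  // so SyncHostToDevice/SyncDeviceToHost copy exactly the same bytes.
  return ElementCount(tensor_shape) == ElementCount(runtime_padding_shape) ? 0 : 1;
}
```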
| @@ -197,8 +197,8 @@ const std::set<std::string> kOptOperatorSet = { | |||||
| kApplyRMSPropOpName, | kApplyRMSPropOpName, | ||||
| }; | }; | ||||
| const std::set<std::string> kSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, | |||||
| kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0}; | |||||
| const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, | |||||
| kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0}; | |||||
| static inline void ChangeFileMode(const std::string& file_name, mode_t mode) { | static inline void ChangeFileMode(const std::string& file_name, mode_t mode) { | ||||
| if (access(file_name.c_str(), F_OK) != 0) { | if (access(file_name.c_str(), F_OK) != 0) { | ||||
| @@ -80,6 +80,8 @@ TEST_F(TestHWLayerNormBetaGammaBackpropFusion, layernorm_beta_gamma_backprop_fus | |||||
| builder1.SetOutputsDeviceType({kNumberTypeFloat32}); | builder1.SetOutputsDeviceType({kNumberTypeFloat32}); | ||||
| cast0->set_kernel_info(std::make_shared<device::KernelInfo>()); | cast0->set_kernel_info(std::make_shared<device::KernelInfo>()); | ||||
| cast1->set_kernel_info(std::make_shared<device::KernelInfo>()); | cast1->set_kernel_info(std::make_shared<device::KernelInfo>()); | ||||
| cast0->set_abstract(x_abstract); | |||||
| cast1->set_abstract(x_abstract); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast0.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast0.get()); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast1.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast1.get()); | ||||
| @@ -211,8 +211,8 @@ TEST_F(AnfRuntimeAlgorithmTest, EraseNodeAttr) { | |||||
| TEST_F(AnfRuntimeAlgorithmTest, GetInputTensorNum) { | TEST_F(AnfRuntimeAlgorithmTest, GetInputTensorNum) { | ||||
| auto kernel_graph = std::make_shared<KernelGraph>(); | auto kernel_graph = std::make_shared<KernelGraph>(); | ||||
| // test cnode node | // test cnode node | ||||
| auto parameter_one = kernel_graph->add_parameter(); | |||||
| auto parameter_two = kernel_graph->add_parameter(); | |||||
| auto parameter_one = kernel_graph->NewParameter(); | |||||
| auto parameter_two = kernel_graph->NewParameter(); | |||||
| std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimTensorAdd), parameter_one, parameter_two}; | std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimTensorAdd), parameter_one, parameter_two}; | ||||
| auto add = kernel_graph->NewCNode(add_inputs); | auto add = kernel_graph->NewCNode(add_inputs); | ||||
| EXPECT_EQ(AnfAlgo::GetInputTensorNum(add), 2); | EXPECT_EQ(AnfAlgo::GetInputTensorNum(add), 2); | ||||
| @@ -247,9 +247,11 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputTensorNum) { | |||||
| TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) { | TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) { | ||||
| auto kernel_graph = std::make_shared<KernelGraph>(); | auto kernel_graph = std::make_shared<KernelGraph>(); | ||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(prim::kPrimTensorAdd)); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(), | |||||
| kernel_graph->NewParameter()}; | |||||
| auto add = kernel_graph->NewCNode(inputs); | auto add = kernel_graph->NewCNode(inputs); | ||||
| std::vector<size_t> shape = {1, 2, 3, 4}; | |||||
| AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get()); | |||||
| MS_EXCEPTION_IF_NULL(add); | MS_EXCEPTION_IF_NULL(add); | ||||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | add->set_kernel_info(std::make_shared<KernelInfo>()); | ||||
| auto d_kernel_info = add->kernel_info(); | auto d_kernel_info = add->kernel_info(); | ||||
| @@ -266,8 +268,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) { | |||||
| TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) { | TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) { | ||||
| auto kernel_graph = std::make_shared<KernelGraph>(); | auto kernel_graph = std::make_shared<KernelGraph>(); | ||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(prim::kPrimTensorAdd)); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(), | |||||
| kernel_graph->NewParameter()}; | |||||
| auto add = kernel_graph->NewCNode(inputs); | auto add = kernel_graph->NewCNode(inputs); | ||||
| MS_EXCEPTION_IF_NULL(add); | MS_EXCEPTION_IF_NULL(add); | ||||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | add->set_kernel_info(std::make_shared<KernelInfo>()); | ||||
| @@ -345,7 +347,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputInferShape) { | |||||
| std::vector<int> shp{2, 32, 224, 224}; | std::vector<int> shp{2, 32, 224, 224}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| // test parameter node as input | // test parameter node as input | ||||
| auto parameter_node = kernel_graph->add_parameter(); | |||||
| auto parameter_node = kernel_graph->NewParameter(); | |||||
| MS_EXCEPTION_IF_NULL(parameter_node); | MS_EXCEPTION_IF_NULL(parameter_node); | ||||
| parameter_node->set_abstract(x_abstract); | parameter_node->set_abstract(x_abstract); | ||||
| EXPECT_THROW(AnfAlgo::GetPrevNodeOutputInferShape(parameter_node, 0), std::runtime_error); | EXPECT_THROW(AnfAlgo::GetPrevNodeOutputInferShape(parameter_node, 0), std::runtime_error); | ||||
| @@ -387,13 +389,13 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) { | |||||
| auto kernel_graph = std::make_shared<KernelGraph>(); | auto kernel_graph = std::make_shared<KernelGraph>(); | ||||
| std::vector<int> shp{2, 32, 224, 224}; | std::vector<int> shp{2, 32, 224, 224}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| auto parameter_one = kernel_graph->add_parameter(); | |||||
| auto parameter_one = kernel_graph->NewParameter(); | |||||
| MS_EXCEPTION_IF_NULL(parameter_one); | MS_EXCEPTION_IF_NULL(parameter_one); | ||||
| parameter_one->set_abstract(x_abstract); | parameter_one->set_abstract(x_abstract); | ||||
| auto parameter_two = kernel_graph->add_parameter(); | |||||
| auto parameter_two = kernel_graph->NewParameter(); | |||||
| MS_EXCEPTION_IF_NULL(parameter_two); | MS_EXCEPTION_IF_NULL(parameter_two); | ||||
| parameter_two->set_abstract(x_abstract); | parameter_two->set_abstract(x_abstract); | ||||
| auto parameter_third = kernel_graph->add_parameter(); | |||||
| auto parameter_third = kernel_graph->NewParameter(); | |||||
| MS_EXCEPTION_IF_NULL(parameter_third); | MS_EXCEPTION_IF_NULL(parameter_third); | ||||
| parameter_third->set_abstract(x_abstract); | parameter_third->set_abstract(x_abstract); | ||||
| // test cnode as input | // test cnode as input | ||||
| @@ -466,8 +468,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) { | |||||
| TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) { | TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) { | ||||
| auto kernel_graph = std::make_shared<KernelGraph>(); | auto kernel_graph = std::make_shared<KernelGraph>(); | ||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(prim::kPrimTensorAdd)); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(), | |||||
| kernel_graph->NewParameter()}; | |||||
| auto add = kernel_graph->NewCNode(inputs); | auto add = kernel_graph->NewCNode(inputs); | ||||
| MS_EXCEPTION_IF_NULL(add); | MS_EXCEPTION_IF_NULL(add); | ||||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | add->set_kernel_info(std::make_shared<KernelInfo>()); | ||||
| @@ -140,11 +140,11 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) { | |||||
| std::vector<int> shape = {2, 32, 224, 224}; | std::vector<int> shape = {2, 32, 224, 224}; | ||||
| auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape); | auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape); | ||||
| auto x_parameter = kernel_graph->add_parameter(); | |||||
| auto x_parameter = kernel_graph->NewParameter(); | |||||
| MS_EXCEPTION_IF_NULL(x_parameter); | MS_EXCEPTION_IF_NULL(x_parameter); | ||||
| x_parameter->set_name("x_parameter"); | x_parameter->set_name("x_parameter"); | ||||
| x_parameter->set_abstract(abstract); | x_parameter->set_abstract(abstract); | ||||
| auto y_parameter = kernel_graph->add_parameter(); | |||||
| auto y_parameter = kernel_graph->NewParameter(); | |||||
| MS_EXCEPTION_IF_NULL(y_parameter); | MS_EXCEPTION_IF_NULL(y_parameter); | ||||
| y_parameter->set_name("y_parameter"); | y_parameter->set_name("y_parameter"); | ||||
| y_parameter->set_abstract(abstract); | y_parameter->set_abstract(abstract); | ||||
| @@ -153,7 +153,7 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) { | |||||
| MS_EXCEPTION_IF_NULL(add); | MS_EXCEPTION_IF_NULL(add); | ||||
| add->set_abstract(abstract); | add->set_abstract(abstract); | ||||
| auto z_parameter = kernel_graph->add_parameter(); | |||||
| auto z_parameter = kernel_graph->NewParameter(); | |||||
| MS_EXCEPTION_IF_NULL(z_parameter); | MS_EXCEPTION_IF_NULL(z_parameter); | ||||
| z_parameter->set_name("z_parameter"); | z_parameter->set_name("z_parameter"); | ||||
| z_parameter->set_abstract(abstract); | z_parameter->set_abstract(abstract); | ||||