Merge pull request !260 from lianliguang/refactor-padding-strategy
@@ -20,6 +20,8 @@
#include <utility>
#include "./securec.h"
#include "common/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel.h"
#include "device/convert_tensor_utils.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
@@ -27,6 +29,33 @@
namespace mindspore {
namespace trans {
namespace {
std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
  std::vector<size_t> shape_4d(4, 1);
  switch (shape.size()) {
    case 0:
      return shape_4d;
    case 1:
      shape_4d[1] = shape[0];
      break;
    case 2:
      shape_4d[1] = shape[0];
      shape_4d[2] = shape[1];
      break;
    case 3:
      shape_4d[1] = shape[0];
      shape_4d[2] = shape[1];
      shape_4d[3] = shape[2];
      break;
    case 4:
      std::copy(shape.begin(), shape.end(), shape_4d.begin());
      break;
    default:
      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
  }
  return shape_4d;
}
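// A minimal standalone sketch (not MindSpore code) of the default padding rule
// above: a lower-rank shape is embedded into NCHW starting at the C axis, and
// the missing dims are filled with 1. All names below are illustrative.
#include <cassert>
#include <vector>

std::vector<size_t> PadTo4dDemo(const std::vector<size_t> &shape) {
  if (shape.size() == 4) {
    return shape;  // rank-4 shapes are kept as-is
  }
  std::vector<size_t> out(4, 1);
  for (size_t i = 0; i < shape.size(); ++i) {
    out[i + 1] = shape[i];  // cases 1-3 above all start writing at index 1 (the C axis)
  }
  return out;
}

int main() {
  assert((PadTo4dDemo({}) == std::vector<size_t>{1, 1, 1, 1}));         // scalar
  assert((PadTo4dDemo({8}) == std::vector<size_t>{1, 8, 1, 1}));        // {C}
  assert((PadTo4dDemo({8, 3}) == std::vector<size_t>{1, 8, 3, 1}));     // {C,H}
  assert((PadTo4dDemo({8, 3, 2}) == std::vector<size_t>{1, 8, 3, 2}));  // {C,H,W}
  return 0;
}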
| } // namespace | |||
| const size_t kNchwDims = 4; | |||
| const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1}, | |||
| {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8}, | |||
| @@ -154,38 +183,64 @@ size_t TypeIdSize(const TypeId data_type) { | |||
| return unsupported_type_error; | |||
| } | |||
std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape) {
bool IsNeedPadding(const std::string &format, const size_t shape_size) {
  if (shape_size == 0) {
    return false;
  }
  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
    return false;
  } else if (shape_size < 4) {
    return true;
  }
  return false;
}
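// Sketch of the rule above, assuming kOpFormat_DEFAULT is "DefaultFormat" and
// kOpFormat_FRAC_NZ is "FRACTAL_NZ" (the string values are an assumption here):
// padding applies only to non-default, non-FRAC_NZ formats with 0 < rank < 4.
#include <cassert>
#include <string>

bool IsNeedPaddingDemo(const std::string &format, size_t rank) {
  if (rank == 0) return false;  // scalar: nothing to pad
  if (format == "DefaultFormat" || format == "FRACTAL_NZ") return false;
  return rank < 4;  // 4-d shapes already fit the device layout
}

int main() {
  assert(!IsNeedPaddingDemo("DefaultFormat", 2));  // default format: no padding
  assert(IsNeedPaddingDemo("NC1HWC0", 2));         // special format, rank 2: pad
  assert(!IsNeedPaddingDemo("NC1HWC0", 4));        // already 4-d: no padding
  return 0;
}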
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  std::vector<int> shape;
  std::vector<size_t> host_shape;
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    auto node_value = value_node->value();
    auto tensor = node_value->cast<tensor::TensorPtr>();
    if (tensor == nullptr) {
      MS_LOG(EXCEPTION) << "The value of node [" << node->DebugString() << "] cannot be converted to a tensor";
    }
    shape = tensor->shape();
    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
    if (host_shape.empty()) {
      host_shape.push_back(1);
    }
  } else {
    host_shape = AnfAlgo::GetOutputInferShape(node, index);
  }
  if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
  }
  shape.clear();  // rebuild from the (possibly padded) host shape instead of appending to the value-node shape
  std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
  return shape;
}
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
  if (padding_axis.empty() || shape.size() != padding_axis.size()) {
    return PaddingShapeTo4dByDefault(shape);
  }
  std::vector<size_t> shape_4d(4, 1);
  switch (shape.size()) {
    case 0:
      break;
    case 1:
      shape_4d[1] = shape[0];
      break;
    case 2:
      shape_4d[0] = shape[0];
      shape_4d[1] = shape[1];
      break;
    case 3:
      MS_LOG(EXCEPTION) << "Unexpected shape size = 3, it should have a default format";
    case 4:
      for (size_t i = 0; i < 4; ++i) {
        shape_4d[i] = shape[i];
      }
      break;
    default:
      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
  for (size_t index = 0; index < padding_axis.size(); index++) {
    shape_4d[padding_axis[index]] = shape[index];
  }
  return shape_4d;
}
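// Standalone sketch of axis-directed padding: each input dim is written to the
// NCHW slot named by its reshape-type axis. The Axis values below (N=0, C=1,
// H=2, W=3) are an assumption for illustration, not kernel::Axis itself.
#include <cassert>
#include <vector>

enum DemoAxis { kN = 0, kC = 1, kH = 2, kW = 3 };

std::vector<size_t> PadByAxisDemo(const std::vector<size_t> &shape, const std::vector<DemoAxis> &axes) {
  std::vector<size_t> shape_4d(4, 1);
  for (size_t i = 0; i < axes.size(); ++i) {
    shape_4d[axes[i]] = shape[i];  // place dim i at its declared NCHW position
  }
  return shape_4d;
}

int main() {
  // A {C,H} tensor whose reshape type declares its dims as C and H: -> {1,C,H,1}.
  assert((PadByAxisDemo({8, 5}, {kC, kH}) == std::vector<size_t>{1, 8, 5, 1}));
  // The same shape declared as {H,W} pads differently: -> {1,1,H,W}.
  assert((PadByAxisDemo({8, 5}, {kH, kW}) == std::vector<size_t>{1, 1, 8, 5}));
  return 0;
}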
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
    return shape;
  }
  auto temp_shape = shape;
  std::vector<size_t> device_shape;
  if (format == kOpFormat_FRAC_NZ) {
    if (shape.size() < 2) {
      MS_EXCEPTION(NotSupportError) << "Format " << format << " does not support shape of size " << shape.size();
    }
    if (shape.size() > 2) {
      MS_LOG(EXCEPTION) << "Format " << format << " does not support shape of size " << shape.size();
    } else {
      (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
    }
    auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
@@ -197,35 +252,36 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
    return device_shape;
  }
  if (shape.size() != 4) {
    MS_LOG(EXCEPTION) << "shape_4d size should be 4";
    MS_LOG(WARNING) << "Getting device shape from a shape of size less than 4; it will be padded to 4d by default first";
    temp_shape = PaddingShapeTo4dByDefault(shape);
  }
  if (format == kOpFormat_NC1HWC0) {
    size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
    size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize;
    size_t C0 = kCubeSize;
    device_shape.push_back(shape[0]);
    device_shape.push_back(temp_shape[0]);
    device_shape.push_back(C1);
    device_shape.push_back(shape[2]);
    device_shape.push_back(shape[3]);
    device_shape.push_back(temp_shape[2]);
    device_shape.push_back(temp_shape[3]);
    device_shape.push_back(C0);
    return device_shape;
  } else if (format == kOpFormat_FRAC_Z) {
    size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
    size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
    device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
    size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
    size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
    device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize);
    device_shape.push_back(cout16 / kCubeSize);
    device_shape.push_back(kCubeSize);
    device_shape.push_back(kCubeSize);
    return device_shape;
  } else if (format == kOpFormat_NHWC) {
    device_shape.push_back(shape[0]);
    device_shape.push_back(shape[2]);
    device_shape.push_back(shape[3]);
    device_shape.push_back(shape[1]);
    device_shape.push_back(temp_shape[0]);
    device_shape.push_back(temp_shape[2]);
    device_shape.push_back(temp_shape[3]);
    device_shape.push_back(temp_shape[1]);
    return device_shape;
  } else if (format == kOpFormat_NCHW) {
    return shape;
  } else if (format == kOpFormat_HWCN) {
    return {shape[2], shape[3], shape[1], shape[0]};
    return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]};
  } else if (format == kOpFormat_NCHW) {
    return temp_shape;
  }
  MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
}
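// Standalone sketch of two of the device-shape rules above, assuming
// kCubeSize == 16 (the Ascend cube unit; an assumption of this sketch):
//   NC1HWC0: {N,C,H,W} -> {N, ceil(C/16), H, W, 16}
//   FRAC_Z : {N,C,H,W} -> {H*W*ceil(C/16), ceil(N/16), 16, 16}
#include <cassert>
#include <cstddef>
#include <vector>

constexpr size_t kCube = 16;  // assumed cube size

std::vector<size_t> Nc1hwc0Demo(const std::vector<size_t> &s) {
  size_t c1 = (s[1] + kCube - 1) / kCube;  // channel dim split into C1 x C0
  return {s[0], c1, s[2], s[3], kCube};
}

std::vector<size_t> FracZDemo(const std::vector<size_t> &s) {
  size_t cout16 = ((s[0] + kCube - 1) / kCube) * kCube;  // round N up to the cube
  size_t cin16 = ((s[1] + kCube - 1) / kCube) * kCube;   // round C up to the cube
  return {s[2] * s[3] * cin16 / kCube, cout16 / kCube, kCube, kCube};
}

int main() {
  // A conv weight of shape {32, 3, 5, 5}:
  assert((Nc1hwc0Demo({32, 3, 5, 5}) == std::vector<size_t>{32, 1, 5, 5, 16}));
  assert((FracZDemo({32, 3, 5, 5}) == std::vector<size_t>{25, 2, 16, 16}));
  return 0;
}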
| @@ -24,6 +24,7 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "ir/dtype.h" | |||
| #include "kernel/kernel.h" | |||
| #include "ir/dtype/type.h" | |||
| namespace mindspore { | |||
| @@ -49,7 +50,10 @@ size_t TypeIdSize(const TypeId data_type); | |||
| size_t ShapeSize(const std::vector<size_t> &shape); | |||
| size_t CubeSizeByType(const TypeId data_type); | |||
| std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape); | |||
| std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, | |||
| const std::vector<kernel::Axis> &padding_axis = {}); | |||
| std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index); | |||
| bool IsNeedPadding(const std::string &format, const size_t shape_size); | |||
| std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format); | |||
| bool TransDataType(const TypeIdArgs &args, void *result); | |||
| bool TransFormat(const FormatArgs &args, void *result); | |||
| @@ -141,7 +141,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int | |||
| if (format_ == kOpFormat_FRAC_NZ) { | |||
| device_shape = trans::TransShapeToDevice(host_shape, format_); | |||
| } else { | |||
| host_shape = trans::TransShapeTo4d(host_shape); | |||
| host_shape = trans::PaddingShapeTo4d(host_shape); | |||
| device_shape = trans::TransShapeToDevice(host_shape, format_); | |||
| } | |||
| if (type_id_ != type) { | |||
| @@ -224,7 +224,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int | |||
| if (format_ == kOpFormat_FRAC_NZ) { | |||
| device_shape = trans::TransShapeToDevice(host_shape, format_); | |||
| } else { | |||
| host_shape = trans::TransShapeTo4d(host_shape); | |||
| host_shape = trans::PaddingShapeTo4d(host_shape); | |||
| device_shape = trans::TransShapeToDevice(host_shape, format_); | |||
| } | |||
| if (type_id_ != type) { | |||
| @@ -27,6 +27,7 @@ | |||
| #include "utils/context/ms_context.h" | |||
| #include "device/ascend/profiling/profiling_manager.h" | |||
| #include "hccl/hcom.h" | |||
| #include "common/trans.h" | |||
| #include "runtime/context.h" | |||
| #include "device/ascend/ascend_stream_assign.h" | |||
| #include "device/ascend/ascend_memory_pool.h" | |||
| @@ -150,7 +151,7 @@ void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path, | |||
| auto output_size = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t j = 0; j < output_size; ++j) { | |||
| auto addr = AnfAlgo::GetOutputAddr(node, j); | |||
| auto shape = AnfAlgo::GetOutputInferShape(node, j); | |||
| auto shape = trans::GetRuntimePaddingShape(node, j); | |||
| auto type = AnfAlgo::GetOutputInferDataType(node, j); | |||
| auto format = kOpFormat_DEFAULT; | |||
| string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j); | |||
| @@ -181,7 +182,7 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p | |||
| continue; | |||
| } | |||
| auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); | |||
| auto shape = AnfAlgo::GetOutputInferShape(item, PRAMATER_OUTPUT_INDEX); | |||
| auto shape = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); | |||
| auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX); | |||
| auto format = kOpFormat_DEFAULT; | |||
| string filepath = dump_path + '/' + parameter_name + '_' + "output_0"; | |||
| @@ -184,7 +184,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons | |||
| } | |||
| if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { | |||
| if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) && | |||
| kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) { | |||
| kNeedTransFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kNeedTransFormatSet.end()) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++; | |||
| } | |||
| (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++; | |||
| @@ -210,19 +210,22 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons | |||
| (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++; | |||
| } | |||
| } | |||
| } // namespace | |||
| } | |||
| void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index)) { | |||
| continue; | |||
| } | |||
| auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); | |||
| MS_EXCEPTION_IF_NULL(input_with_index.first); | |||
| auto real_input_node = input_with_index.first; | |||
| if (real_input_node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) { | |||
| continue; | |||
| } | |||
| std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
// we set the special device info of an input tensor.
| @@ -25,6 +25,7 @@ | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "common/trans.h" | |||
| #include "utils/config_manager.h" | |||
| #include "common/utils.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| @@ -391,7 +392,8 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &c | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| tensor->set_device_address(device_address); | |||
| if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| MS_LOG(INFO) << "SyncHostToDevice failed."; | |||
| return false; | |||
| @@ -31,6 +31,7 @@ class KernelInfo { | |||
| public: | |||
| KernelInfo() { | |||
| kernel_mod_ = nullptr; | |||
| is_feature_map_ = false; | |||
| select_kernel_build_info_ = nullptr; | |||
| output_address_list_ = {}; | |||
| workspace_address_list_ = {}; | |||
| @@ -45,6 +46,7 @@ class KernelInfo { | |||
| void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||
| select_kernel_build_info_ = select_kernel_build_info; | |||
| } | |||
| void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; } | |||
| const DeviceAddress *GetOutputAddr(size_t index) const; | |||
| DeviceAddressPtr GetMutableOutputAddr(size_t index) const; | |||
| bool OutputAddrExist(size_t index) const; | |||
| @@ -63,8 +65,10 @@ class KernelInfo { | |||
| void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } | |||
| uint32_t graph_id() const { return graph_id_; } | |||
| bool operator==(const KernelInfo &other) const; | |||
| bool is_feature_map() const { return is_feature_map_; } | |||
| private: | |||
| bool is_feature_map_; | |||
| kernel::KernelBuildInfoPtr select_kernel_build_info_; | |||
| std::vector<std::shared_ptr<DeviceAddress>> output_address_list_; | |||
| std::vector<std::shared_ptr<DeviceAddress>> workspace_address_list_; | |||
| @@ -105,7 +105,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod | |||
| std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index); | |||
| auto format = AnfAlgo::GetOutputFormat(node, output_index); | |||
| if (shape.empty() && format != kOpFormat_DEFAULT) { | |||
| shape = trans::TransShapeTo4d(shape); | |||
| shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index)); | |||
| shape = trans::TransShapeToDevice(shape, format); | |||
| } | |||
| // scalar's output shape is a empty vector | |||
| @@ -401,8 +401,9 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const | |||
| auto address = CreateDeviceAddress(ptr, node_size, AnfAlgo::GetOutputFormat(value_node, output_idx), output_type_id); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); | |||
| if (!address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), tensor->data_c(false))) { | |||
| MS_EXCEPTION(NotExistsError) << "kValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is" | |||
| if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is" | |||
| << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is " | |||
| << AnfAlgo::GetOutputInferDataType(value_node, output_idx); | |||
| } | |||
| @@ -421,19 +422,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(node_value); | |||
| if (node_value->isa<Tensor>()) { | |||
| AssignValueNodeTensor(value_node, node_value, 0); | |||
| } else if (node_value->isa<ValueTuple>()) { | |||
| auto value_tuple = node_value->cast<ValueTuplePtr>(); | |||
| if (value_tuple == nullptr) { | |||
| MS_LOG(WARNING) << "value_tuple is null"; | |||
| continue; | |||
| } | |||
| size_t i = 0; | |||
| auto value_list = value_tuple->value(); | |||
| for (auto value_ptr : value_list) { | |||
| if (value_ptr->isa<Tensor>()) { | |||
| AssignValueNodeTensor(value_node, value_ptr, i++); | |||
| } | |||
| } | |||
| } else if (node_value->isa<StringImm>()) { | |||
| auto value = GetValue<std::string>(node_value); | |||
| size_t tensor_size = value.size(); | |||
@@ -59,30 +59,20 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
bool KernelBuildInfo::GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const {
  MS_EXCEPTION_IF_NULL(reshape_type);
  reshape_type->clear();
std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
  if (input_index >= input_reshape_type_.size()) {
    MS_LOG(WARNING) << "The index [" << input_index << "] exceeds the number of inputs: "
                    << input_reshape_type_.size();
    return false;
    MS_LOG(EXCEPTION) << "The index [" << input_index << "] exceeds the number of inputs: "
                      << input_reshape_type_.size();
  }
  (void)std::copy(input_reshape_type_[input_index].begin(), input_reshape_type_[input_index].end(),
                  std::inserter(*reshape_type, (*reshape_type).begin()));
  return true;
  return input_reshape_type_[input_index];
}
bool KernelBuildInfo::GetOutputReshapeType(size_t output_index, std::vector<Axis> *reshape_type) const {
  MS_EXCEPTION_IF_NULL(reshape_type);
  reshape_type->clear();
std::vector<Axis> KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
  if (output_index >= output_reshape_type_.size()) {
    MS_LOG(WARNING) << "The index [" << output_index << "] exceeds the number of outputs: "
                    << output_reshape_type_.size();
    return false;
    MS_LOG(EXCEPTION) << "The index [" << output_index << "] exceeds the number of outputs: "
                      << output_reshape_type_.size();
  }
  (void)std::copy(output_reshape_type_[output_index].begin(), output_reshape_type_[output_index].end(),
                  std::inserter(*reshape_type, (*reshape_type).begin()));
  return true;
  return output_reshape_type_[output_index];
}
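// The refactor above replaces the bool-plus-out-parameter getters with
// bounds-checked return-by-value ones. A minimal sketch of that pattern (a
// plain std::out_of_range stands in for MS_LOG(EXCEPTION); names illustrative):
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

class ReshapeTypeTableDemo {
 public:
  explicit ReshapeTypeTableDemo(std::vector<std::vector<int>> table) : table_(std::move(table)) {}
  // New style: the caller either gets the value or an exception,
  // never a half-filled output buffer plus a bool it may forget to check.
  std::vector<int> Get(std::size_t index) const {
    if (index >= table_.size()) {
      throw std::out_of_range("index exceeds the number of entries");
    }
    return table_[index];
  }

 private:
  std::vector<std::vector<int>> table_;
};

int main() {
  ReshapeTypeTableDemo demo({{0, 1}, {2, 3}});
  return demo.Get(1).size() == 2 ? 0 : 1;
}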
| std::string KernelBuildInfo::ToString() const { | |||
| @@ -115,6 +105,10 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { | |||
| return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_); | |||
| } | |||
bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
| void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) { | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||
| kernel_build_info_->kernel_type_ = kernel_type; | |||
| @@ -54,9 +54,13 @@ class KernelBuildInfo { | |||
| TypeId GetOutputDeviceType(size_t output_index) const; | |||
| bool GetInputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const; | |||
| std::vector<Axis> GetInputReshapeType(size_t input_index) const; | |||
| bool GetOutputReshapeType(size_t input_index, std::vector<Axis> *reshape_type) const; | |||
| bool IsInputDefaultPadding() const; | |||
| bool IsOutputDefaultPadding() const; | |||
| std::vector<Axis> GetOutputReshapeType(size_t input_index) const; | |||
| std::vector<std::string> GetAllInputFormats() const; | |||
| @@ -18,20 +18,21 @@ | |||
| #include <set> | |||
| #include "common/trans.h" | |||
| #include "common/utils.h" | |||
| #include "utils/utils.h" | |||
| #include "device/kernel_info.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "operator/ops.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | |||
| namespace { | |||
| kernel::KernelBuildInfoPtr CreateKernelBuildInfo(const std::string &input_format, const std::string &output_format, | |||
| const AnfNodePtr &node, const kernel::KernelBuildInfo ori_build_info) { | |||
| kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, | |||
| const AnfNodePtr &node, | |||
| const kernel::KernelBuildInfo ori_build_info) { | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat({input_format}); | |||
| builder.SetOutputsFormat({output_format}); | |||
@@ -54,9 +55,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
MS_EXCEPTION_IF_NULL(trans_node);
if (need_padding) {
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                      {trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0))},
                                      trans_node.get());
  // if padding is needed, set the transdata node's shape to the padded shape
  AnfAlgo::SetOutputInferTypeAndShape(
    {AnfAlgo::GetOutputInferDataType(input, 0)},
    {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
    trans_node.get());
} else {
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                      {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
| @@ -92,9 +95,11 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i | |||
| AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, | |||
| const KernelSelectPtr &kernel_select) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| bool padding_flag = false; | |||
| auto input_node = AnfAlgo::GetInputNode(node, index); | |||
| if (input_node->isa<ValueNode>() || input_node->isa<Parameter>()) { | |||
| auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); | |||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||
| auto real_input = node_with_index.first; | |||
| if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) { | |||
| input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| AnfAlgo::SetNodeInput(node, input_node, index); | |||
| @@ -106,33 +111,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & | |||
| std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); | |||
| std::string origin_format = kOpFormat_DEFAULT; | |||
| std::string dest_format = AnfAlgo::GetInputFormat(node, index); | |||
| if (dest_format == kOpFormat_C1HWNCoC0) { | |||
| padding_flag = (origin_shape.size() != kShape4dDims); | |||
| AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag, | |||
| origin_format, dest_format, kTransDataOpName, true); | |||
| MS_EXCEPTION_IF_NULL(replace_input); | |||
| return replace_input; | |||
| } | |||
| if (dest_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) { | |||
| padding_flag = (origin_shape.size() != kShape4dDims); | |||
| AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag, | |||
| origin_format, dest_format, kTransDataOpName, true); | |||
| MS_EXCEPTION_IF_NULL(replace_input); | |||
| MS_LOG(DEBUG) << "Inserted Translate45, index: " << index; | |||
| return replace_input; | |||
| } else if (dest_format == kOpFormat_FRAC_NZ) { | |||
| AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag, | |||
| origin_format, dest_format, kTransDataOpName, true); | |||
| MS_EXCEPTION_IF_NULL(replace_input); | |||
| MS_LOG(DEBUG) << "inserted translate " << AnfAlgo::GetInputFormat(node, index) << " To default, index: " << index; | |||
| return replace_input; | |||
| } else if (dest_format == kOpFormat_FRAC_Z && !origin_shape.empty()) { | |||
| padding_flag = (origin_shape.size() != kShape4dDims); | |||
| AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag, | |||
| origin_format, dest_format, kTransDataOpName, true); | |||
| MS_EXCEPTION_IF_NULL(replace_input); | |||
| MS_LOG(DEBUG) << "Inserted Translate45, index: " << index; | |||
| return replace_input; | |||
  if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
    MS_LOG(DEBUG) << node->DebugString() << " insert transdata " << AnfAlgo::GetInputFormat(node, index)
                  << " to default format, index: " << index;
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName,
                                 true);
| } | |||
| return input_node; | |||
| } | |||
| @@ -140,7 +123,6 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & | |||
| AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const KernelSelectPtr &kernel_select) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| bool padding_flag = false; | |||
| std::string output_format; | |||
| std::vector<size_t> origin_shape; | |||
| if (!AnfAlgo::IsRealKernel(node)) { | |||
| @@ -156,46 +138,14 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An | |||
| } | |||
| std::string origin_format = output_format; | |||
| std::string dest_format = kOpFormat_DEFAULT; | |||
| if (output_format == kOpFormat_C1HWNCoC0) { | |||
| padding_flag = (origin_shape.size() != kShape4dDims); | |||
| AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format, | |||
| dest_format, kTransDataOpName, false); | |||
| MS_EXCEPTION_IF_NULL(replace_input); | |||
| return replace_input; | |||
| } | |||
| if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) { | |||
| padding_flag = (origin_shape.size() != kShape4dDims); | |||
| AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format, | |||
| dest_format, kTransDataOpName, false); | |||
| MS_EXCEPTION_IF_NULL(replace_output); | |||
| MS_LOG(DEBUG) << "Inserted Trans54"; | |||
| return replace_output; | |||
| } else if (output_format == kOpFormat_FRAC_NZ) { | |||
| AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format, | |||
| dest_format, kTransDataOpName, false); | |||
| MS_EXCEPTION_IF_NULL(replace_output); | |||
| MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: 0"; | |||
| return replace_output; | |||
| } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) { | |||
| padding_flag = (origin_shape.size() != kShape4dDims); | |||
| AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format, | |||
| dest_format, kTransDataOpName, false); | |||
| MS_EXCEPTION_IF_NULL(replace_output); | |||
| MS_LOG(DEBUG) << "Inserted Trans54"; | |||
| return replace_output; | |||
  if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
    MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default, index: 0";
    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName,
                                 false);
| } | |||
| return node; | |||
| } | |||
| void GetTransDataInputFormat(const AnfNodePtr &node, size_t idx, std::string *input_format) { | |||
| MS_EXCEPTION_IF_NULL(input_format); | |||
| if (AnfAlgo::IsRealKernel(node)) { | |||
| *input_format = AnfAlgo::GetOutputFormat(node, idx); | |||
| } else { | |||
| *input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0); | |||
| } | |||
| } | |||
| AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const KernelSelectPtr &kernel_select) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -203,46 +153,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) { | |||
| bool padding_flag = false; | |||
| std::string output_format; | |||
| GetTransDataInputFormat(node, output_idx, &output_format); | |||
| std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); | |||
    if (output_format == kOpFormat_NC1KHKWHWC0) {
      MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node "
      MS_LOG(EXCEPTION) << "Got the special format " << output_format << " when inserting the transdata node "
                        << node->DebugString();
    }
| auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); | |||
| std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); | |||
| std::string origin_format = output_format; | |||
| std::string dest_format = kOpFormat_DEFAULT; | |||
| if (output_format == kOpFormat_C1HWNCoC0) { | |||
| padding_flag = (origin_shape.size() != kShape4dDims); | |||
| AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag, | |||
| origin_format, dest_format, kTransDataOpName, false); | |||
| MS_EXCEPTION_IF_NULL(replace_input); | |||
| return replace_input; | |||
| } | |||
| if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) { | |||
| padding_flag = (origin_shape.size() != kShape4dDims); | |||
| // Insert a 5to4 trans op. | |||
| AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag, | |||
| origin_format, dest_format, kTransDataOpName, false); | |||
| MS_EXCEPTION_IF_NULL(replace_output); | |||
| MS_LOG(DEBUG) << "Inserted Translate54"; | |||
| make_tuple_inputs.push_back(replace_output); | |||
| } else if (output_format == kOpFormat_FRAC_NZ) { | |||
| AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag, | |||
| origin_format, dest_format, kTransDataOpName, false); | |||
| MS_EXCEPTION_IF_NULL(replace_output); | |||
| MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: " << output_idx; | |||
| make_tuple_inputs.push_back(replace_output); | |||
| } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) { | |||
| padding_flag = (origin_shape.size() != kShape4dDims); | |||
| AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag, | |||
| origin_format, dest_format, kTransDataOpName, false); | |||
| MS_EXCEPTION_IF_NULL(replace_output); | |||
| MS_LOG(DEBUG) << "Inserted Translate54"; | |||
| make_tuple_inputs.push_back(replace_output); | |||
| if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { | |||
| make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format, | |||
| dest_format, kTransDataOpName, false)); | |||
| } else { | |||
| // No need insert trans op. | |||
| make_tuple_inputs.push_back(tuple_getitem); | |||
| @@ -253,16 +174,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const | |||
| } | |||
} // namespace
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, const bool padding_flag,
                                 const KernelSelectPtr &kernel_select, size_t insert_index,
                                 const std::string &origin_format, const std::string &dest_format,
                                 const std::string &op_name, bool is_insert_input) {
  AnfNodePtr trans_node = nullptr;
  AnfNodePtr input_node = nullptr;
  AnfNodePtr input_node = node;
  AnfNodePtr trans_data = nullptr;
  MS_EXCEPTION_IF_NULL(node);
  if (origin_format.empty() || dest_format.empty()) {
    MS_LOG(EXCEPTION) << "Trans op format is invalid, origin = " << origin_format << ", dest = " << dest_format;
  }
  // if we insert a transdata for an input, we need to replace that input
  if (is_insert_input) {
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Cannot insert a transdata node as the input of a node that is not a cnode";
@@ -270,29 +192,34 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
    if (padding_flag) {
      auto padd_shape = trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0));
      auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padd_shape);
      trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, padding_flag, op_name);
    } else {
      trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
    }
  }
  bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
                       op_name == kTransDataOpName);
  if (!need_padding) {
    // no padding is needed; insert the transdata only
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
    trans_node = trans_data;
  } else if (is_insert_input) {
    // padding is needed for an input, so insert
    // reshape[padding shape] -> transdata[padding shape] -> node
    auto padding_shape =
      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name);
    trans_node = trans_data;
  } else {
    input_node = node;
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name);
    if (padding_flag) {
      auto reshape_node =
        CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
      trans_node = reshape_node;
    } else {
      trans_node = trans_data;
    }
    // padding is needed for an output, so insert
    // node -> transdata[padding shape] -> reshape[original shape]
    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
    auto reshape_node =
      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
    trans_node = reshape_node;
  }
  // refresh the transdata's format to the origin format & dest format
  MS_EXCEPTION_IF_NULL(trans_data);
  MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
  auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
  auto kernel_build_info = CreateKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
  auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
  return trans_node;
}
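// A toy sketch of the two padding insertion orders above, with graph nodes
// modeled as strings (all names illustrative). On the input side the reshape
// pads before the transdata; on the output side it un-pads after it.
#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> BuildChainDemo(bool need_padding, bool is_insert_input) {
  if (!need_padding) {
    return {"transdata"};  // no padding: a lone transdata suffices
  }
  if (is_insert_input) {
    // reshape[padded shape] -> transdata[padded shape] -> consumer
    return {"reshape(pad)", "transdata"};
  }
  // producer -> transdata[padded shape] -> reshape[original shape]
  return {"transdata", "reshape(unpad)"};
}

int main() {
  assert((BuildChainDemo(true, true) == std::vector<std::string>{"reshape(pad)", "transdata"}));
  assert((BuildChainDemo(true, false) == std::vector<std::string>{"transdata", "reshape(unpad)"}));
  return 0;
}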
| @@ -376,7 +303,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||
| TypeId origin_type; | |||
| auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); | |||
| if (!AnfAlgo::IsFeatureMapInput(cnode, input_index)) { | |||
| auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); | |||
| auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { | |||
| if (node->isa<ValueNode>()) { | |||
| return true; | |||
| } else if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) { | |||
| return true; | |||
| } | |||
| return false; | |||
| }; | |||
| auto real_input_node = kernel_with_index.first; | |||
| if (is_weight_boundary(real_input_node)) { | |||
| // weight | |||
| origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); | |||
| } else { | |||
| @@ -48,7 +48,7 @@ class KernelQuery { | |||
| using KernelQueryPtr = std::shared_ptr<KernelQuery>; | |||
| AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const KernelSelectPtr &kernel_select, size_t insert_index, bool padding_flag, | |||
| const KernelSelectPtr &kernel_select, size_t insert_index, | |||
| const std::string &origin_format, const std::string &dest_format, | |||
| const std::string &op_name, bool is_insert_input); | |||
| @@ -105,10 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||
| // insert trans | |||
| if (origin_format != cur_format) { | |||
| auto kernel_select = std::make_shared<KernelSelect>(); | |||
| bool need_padding = | |||
| (cur_format == kOpFormat_NC1HWC0 && AnfAlgo::GetOutputInferShape(final_node, 0).size() != kShape4dDims); | |||
| final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, need_padding, cur_format, | |||
| origin_format, kTransDataOpName, false); | |||
| final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format, | |||
| kTransDataOpName, false); | |||
| final_index = 0; | |||
| MS_EXCEPTION_IF_NULL(final_node); | |||
| MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); | |||
| @@ -1,99 +1,99 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fusion/transdata_split.h" | |||
| #include <set> | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, | |||
| {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, | |||
| {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, | |||
| {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; | |||
| bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| bool changed = false; | |||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { | |||
| CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum); | |||
| if (IsFormatInvaild(node)) { | |||
| changed = DoSplit(func_graph, node); | |||
| } | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_format = AnfAlgo::GetInputFormat(node, 0); | |||
| auto output_format = AnfAlgo::GetOutputFormat(node, 0); | |||
| auto format_pair = std::make_pair(input_format, output_format); | |||
| return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); | |||
| } | |||
// transdata cannot trans frac_z to nchw directly; split into transdata(frac_z->HWCN) and transpose(HWCN->NCHW)
| bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_node = node->cast<CNodePtr>()->input(1); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| auto input_format = AnfAlgo::GetInputFormat(node, 0); | |||
| auto output_format = AnfAlgo::GetOutputFormat(node, 0); | |||
| AnfNodePtr new_transdata_node = nullptr; | |||
| AnfNodePtr new_transpose_node = nullptr; | |||
| AnfNodePtr new_replace_node = nullptr; | |||
// if output_format is default, split into transdata->transpose; otherwise transpose->transdata
| if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { | |||
| // trans input_format to hwcn | |||
| new_transdata_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN, | |||
| kTransDataOpName, true); | |||
| // trans hwcn to default_format | |||
| new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, false, kOpFormat_HWCN, | |||
| output_format, prim::kPrimTranspose->name(), false); | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node); | |||
| new_replace_node = new_transpose_node; | |||
| } else { | |||
| // trans default to hwcn | |||
| new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN, | |||
| prim::kPrimTranspose->name(), true); | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node); | |||
| // trans hwcn to output_format | |||
| new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, false, kOpFormat_HWCN, | |||
| output_format, kTransDataOpName, false); | |||
| new_replace_node = new_transdata_node; | |||
| } | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(func_graph); | |||
| if (!manager->Replace(node, new_replace_node)) { | |||
| MS_LOG(EXCEPTION) << "manager replace node failed"; | |||
| } | |||
| MS_LOG(INFO) << "transdata node:" << cnode->DebugString() << "split success."; | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fusion/transdata_split.h" | |||
| #include <set> | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, | |||
| {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, | |||
| {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, | |||
| {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; | |||
| bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| bool changed = false; | |||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { | |||
| CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum); | |||
| if (IsFormatInvaild(node)) { | |||
| changed = DoSplit(func_graph, node); | |||
| } | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_format = AnfAlgo::GetInputFormat(node, 0); | |||
| auto output_format = AnfAlgo::GetOutputFormat(node, 0); | |||
| auto format_pair = std::make_pair(input_format, output_format); | |||
| return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); | |||
| } | |||
// transdata cannot trans frac_z to nchw directly; split into transdata(frac_z->HWCN) and transpose(HWCN->NCHW)
| bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_node = node->cast<CNodePtr>()->input(1); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| auto input_format = AnfAlgo::GetInputFormat(node, 0); | |||
| auto output_format = AnfAlgo::GetOutputFormat(node, 0); | |||
| AnfNodePtr new_transdata_node = nullptr; | |||
| AnfNodePtr new_transpose_node = nullptr; | |||
| AnfNodePtr new_replace_node = nullptr; | |||
// if output_format is default, split into transdata->transpose; otherwise transpose->transdata
| if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { | |||
| // trans input_format to hwcn | |||
| new_transdata_node = | |||
| AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true); | |||
| // trans hwcn to default_format | |||
| new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN, | |||
| output_format, prim::kPrimTranspose->name(), false); | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node); | |||
| new_replace_node = new_transpose_node; | |||
| } else { | |||
| // trans default to hwcn | |||
| new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, | |||
| prim::kPrimTranspose->name(), true); | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node); | |||
| // trans hwcn to output_format | |||
| new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN, | |||
| output_format, kTransDataOpName, false); | |||
| new_replace_node = new_transdata_node; | |||
| } | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(func_graph); | |||
| if (!manager->Replace(node, new_replace_node)) { | |||
| MS_LOG(EXCEPTION) << "Manager replace node failed"; | |||
| } | |||
| MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success."; | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
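// Standalone check of the two permutations used above. With output[i] =
// input[perm[i]], perm {3,2,0,1} maps HWCN to NCHW and perm {2,3,1,0} maps
// NCHW back to HWCN, matching the kAttrPerm values set on the transpose nodes.
#include <cassert>
#include <string>
#include <vector>

std::string ApplyPerm(const std::string &layout, const std::vector<int> &perm) {
  std::string out;
  for (int p : perm) {
    out += layout[p];  // dimension i of the output comes from input dim perm[i]
  }
  return out;
}

int main() {
  assert(ApplyPerm("HWCN", {3, 2, 0, 1}) == "NCHW");  // transdata(frac_z->HWCN), then this transpose
  assert(ApplyPerm("NCHW", {2, 3, 1, 0}) == "HWCN");  // transpose first, then transdata(HWCN->output_format)
  return 0;
}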
| @@ -291,6 +291,11 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { | |||
| std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (output_idx > GetOutputTensorNum(node)) { | |||
| MS_LOG(EXCEPTION) << "Output index:" << output_idx | |||
| << " is out of the node output range :" << GetOutputTensorNum(node) << " #node [" | |||
| << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| @@ -300,6 +305,11 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t | |||
| std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (input_idx > GetInputTensorNum(node)) { | |||
| MS_LOG(EXCEPTION) << "Input index :" << input_idx | |||
| << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node [" | |||
| << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| @@ -364,62 +374,60 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNo | |||
| std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) { | |||
| auto format = GetOutputFormat(node, output_idx); | |||
| auto infer_shape = GetOutputInferShape(node, output_idx); | |||
| // if format is default_format or NC1KHKWHWC0,device shape = original shape | |||
| if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) { | |||
| return infer_shape; | |||
| } | |||
| // scalar shape | |||
| if (infer_shape.empty()) { | |||
| return infer_shape; | |||
| } | |||
| if (format == kOpFormat_FRAC_NZ) { | |||
| return trans::TransShapeToDevice(infer_shape, format); | |||
| // if format is default_format or NC1KHKWHWC0,device shape = original shape | |||
| if (trans::IsNeedPadding(format, infer_shape.size())) { | |||
| infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx)); | |||
| } | |||
| // else trans infer shape to 4d and then calculate device shape | |||
| return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format); | |||
| return trans::TransShapeToDevice(infer_shape, format); | |||
| } | |||
| std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { | |||
| auto format = GetInputFormat(node, input_idx); | |||
| auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx); | |||
| // if format is default_format or NC1KHKWHWC0,device shape = original shape | |||
| if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) { | |||
| return infer_shape; | |||
| } | |||
| if (infer_shape.empty()) { | |||
| return infer_shape; | |||
| } | |||
| if (format == kOpFormat_FRAC_NZ) { | |||
| return trans::TransShapeToDevice(infer_shape, format); | |||
| // if format is default_format or NC1KHKWHWC0,device shape = original shape | |||
| if (trans::IsNeedPadding(format, infer_shape.size())) { | |||
| infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); | |||
| } | |||
| // else trans infer shape to 4d and then calculate device shape | |||
| return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format); | |||
| return trans::TransShapeToDevice(infer_shape, format); | |||
| } | |||
| std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (input_idx > GetInputTensorNum(node)) { | |||
| MS_LOG(EXCEPTION) << "The index:" << input_idx | |||
| << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node[" | |||
| << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| std::vector<kernel::Axis> result; | |||
| if (!build_info->GetInputReshapeType(input_idx, &result)) { | |||
| MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !"; | |||
| if (build_info->IsInputDefaultPadding()) { | |||
| return {}; | |||
| } | |||
| return result; | |||
| return build_info->GetInputReshapeType(input_idx); | |||
| } | |||
| std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (output_idx > GetOutputTensorNum(node)) { | |||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " | |||
| << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| std::vector<kernel::Axis> result; | |||
| if (!build_info->GetOutputReshapeType(output_idx, &result)) { | |||
| MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !"; | |||
| if (build_info->IsOutputDefaultPadding()) { | |||
| return {}; | |||
| } | |||
| return result; | |||
| return build_info->GetOutputReshapeType(output_idx); | |||
| } | |||
| TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { | |||
| @@ -465,6 +473,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &nod | |||
| TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (output_idx > GetOutputTensorNum(node)) { | |||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " | |||
| << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| @@ -474,6 +486,10 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size | |||
| TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (input_idx > GetInputTensorNum(node)) { | |||
| MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ " | |||
| << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| @@ -498,11 +514,15 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, | |||
MS_LOG(EXCEPTION) << node->DebugString() << " Invalid nop node";
| } | |||
| } | |||
| if (output_idx > GetOutputTensorNum(node)) { | |||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " | |||
| << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto addr = kernel_info->GetOutputAddr(output_idx); | |||
| if (addr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "output_idx " << output_idx << " of node " << node->DebugString() | |||
| MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() | |||
| << " output addr is not exist"; | |||
| } | |||
| return addr; | |||
| @@ -519,11 +539,15 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod | |||
| MS_LOG(EXCEPTION) << node->DebugString() << " Invalid nop node."; | |||
| } | |||
| } | |||
| if (output_idx >= GetOutputTensorNum(node)) { | |||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [" | |||
| << GetOutputTensorNum(node) << "] node: [" << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto addr = kernel_info->GetMutableOutputAddr(output_idx); | |||
| if (addr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "output_idx" << output_idx << " of node " << node->DebugString() | |||
| MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString() | |||
| << " output addr is not exist"; | |||
| } | |||
| return addr; | |||
| @@ -532,6 +556,10 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod | |||
| // check if the output device addr of anf_node exists | |||
| bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (output_idx >= GetOutputTensorNum(node)) { | |||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [" | |||
| << GetOutputTensorNum(node) << "] node: [" << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->OutputAddrExist(output_idx); | |||
| @@ -771,22 +799,24 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) | |||
| return node->input(get_input_index); | |||
| } | |||
| bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<ValueNode>()) { | |||
| return false; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->is_feature_map(); | |||
| } | |||
| bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) { | |||
| if (!node->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature"; | |||
| MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map"; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_node = cnode->input(input_index + 1); | |||
| auto node_with_index = VisitKernel(input_node, 0); | |||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||
| if (node_with_index.first->isa<ValueNode>()) { | |||
| return false; | |||
| } | |||
| if (node_with_index.first->isa<Parameter>()) { | |||
| return !AnfAlgo::IsParameterWeight(node_with_index.first->cast<ParameterPtr>()); | |||
| } | |||
| return true; | |||
| return IsFeatureMapOutput(input_node); | |||
| } | |||
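IsFeatureMapInput now delegates to IsFeatureMapOutput, so the classification rule lives in one place. Restated as a standalone predicate (illustrative only; it mirrors the code above and the flag-setting added to KernelGraph below, and adds no behaviour):

```cpp
// decision table now shared by IsFeatureMapOutput / IsFeatureMapInput:
//   ValueNode           -> false (constant data is never a feature map)
//   Parameter (weight)  -> false (flag cleared in KernelGraph::NewParameter)
//   Parameter (data)    -> true  (flag set in KernelGraph::NewParameter)
//   CNode               -> flag computed when KernelGraph::NewCNode built it
bool IsFeatureMapSketch(const AnfNodePtr &node) {
  if (node->isa<ValueNode>()) {
    return false;
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->is_feature_map();
}
```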
| size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) { | |||
| @@ -102,7 +102,9 @@ class AnfRuntimeAlgorithm { | |||
| static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); | |||
| // get input shapes which will be built and run on device | |||
| static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); | |||
| // get input padding axis | |||
| static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t input_idx); | |||
| // get output padding axis | |||
| static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); | |||
| // get output data type inferred by ME of anf node | |||
| static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); | |||
| @@ -166,6 +168,9 @@ class AnfRuntimeAlgorithm { | |||
| // get graph id | |||
| static uint32_t GetGraphId(const AnfNode *node); | |||
| static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); | |||
| // check whether the node's output is a feature map | |||
| static bool IsFeatureMapOutput(const AnfNodePtr &node); | |||
| // check whether the node's input is from a feature map output | |||
| static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); | |||
| // get real input index for some tbe ops which input order is different between me and tbe impl | |||
| static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); | |||
| @@ -18,6 +18,7 @@ | |||
| #include "operator/ops.h" | |||
| #include "ir/meta_tensor.h" | |||
| #include "ir/anf.h" | |||
| #include "common/trans.h" | |||
| #include "device/kernel_runtime.h" | |||
| #include "device/ascend/kernel_select_ascend.h" | |||
| #include "device/ascend/kernel_build_ascend.h" | |||
| @@ -730,8 +731,8 @@ void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor | |||
| size_t tensor_size = front_tensor->data().nbytes(); | |||
| auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0); | |||
| MS_EXCEPTION_IF_NULL(addr); | |||
| if (!addr->SyncHostToDevice(front_tensor->shape(), tensor_size, front_tensor->data_type(), | |||
| front_tensor->data_c(false))) { | |||
| if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size, | |||
| front_tensor->data_type(), front_tensor->data_c(false))) { | |||
| MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!"; | |||
| } | |||
| MS_LOG(INFO) << "Finish!"; | |||
| @@ -143,6 +143,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| cnode->set_abstract(std::make_shared<abstract::AbstractNone>()); | |||
| // create kernel_info from new parameter | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| // if the node has only the primitive (such as getNext), or any of its inputs is a feature map, | |||
| // then the node's output is a feature map | |||
| if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(), | |||
| [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) { | |||
| kernel_info->SetFeatureMapFlag(true); | |||
| } | |||
| cnode->set_kernel_info(kernel_info); | |||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | |||
| return cnode; | |||
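The flag is decided once, at construction: a cnode that has only its primitive as input (a source op such as getNext) or that consumes at least one feature-map output is itself a feature map. A small hedged example of the propagation, using the graph-building calls exercised by the unit tests later in this PR:

```cpp
// assumed usage; mirrors the updated unit tests further down
auto graph = std::make_shared<session::KernelGraph>();
auto x = graph->NewParameter();  // no default value -> feature-map flag true
std::vector<AnfNodePtr> add_inputs = {NewValueNode(prim::kPrimTensorAdd), x, x};
auto add = graph->NewCNode(add_inputs);
// 'add' consumes a feature-map input, so NewCNode set its flag:
//   AnfAlgo::IsFeatureMapOutput(add) == true
// a cnode whose inputs are only the primitive (inputs.size() == 1) gets the
// flag as well, covering data-producing source ops
```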
| @@ -162,22 +168,26 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { | |||
| ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { | |||
| ParameterPtr new_parameter = add_parameter(); | |||
| MS_EXCEPTION_IF_NULL(new_parameter); | |||
| // create kernel_info from the new parameter | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| size_t output_tensor_num = 1; | |||
| // if the default parameter == nullptr is used, it means creating a new parameter without an old one | |||
| if (parameter == nullptr) { | |||
| new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>()); | |||
| kernel_info->SetFeatureMapFlag(true); | |||
| } else { | |||
| // otherwise it means creating a new parameter from an old parameter | |||
| new_parameter->set_abstract(parameter->abstract()); | |||
| new_parameter->set_name(parameter->name()); | |||
| if (parameter->has_default()) { | |||
| if (AnfAlgo::IsParameterWeight(parameter)) { | |||
| new_parameter->set_default_param(parameter->default_param()); | |||
| kernel_info->SetFeatureMapFlag(false); | |||
| } else { | |||
| kernel_info->SetFeatureMapFlag(true); | |||
| } | |||
| // if the output is a tuple tensor, a for loop can now be used to handle the tuple tensor | |||
| output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter); | |||
| } | |||
| // create kernel_info form new parameter | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| new_parameter->set_kernel_info(kernel_info); | |||
| // create kernel_build_info for new parameter | |||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| @@ -217,6 +227,7 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNo | |||
| AddValueNodeToGraph(new_value_node); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| new_value_node->set_kernel_info(kernel_info); | |||
| kernel_info->SetFeatureMapFlag(false); | |||
| // create kernel_build_info for new value node | |||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| // set the format of value_node to DEFAULT_FORMAT | |||
| @@ -240,6 +251,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { | |||
| new_value_node->set_abstract(value_node->abstract()); | |||
| // create kernel_info for new value node | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| kernel_info->SetFeatureMapFlag(false); | |||
| new_value_node->set_kernel_info(kernel_info); | |||
| // create kernel_build_info for new value node | |||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| @@ -20,6 +20,7 @@ | |||
| #include "pipeline/parse/data_converter.h" | |||
| #include "ir/manager.h" | |||
| #include "operator/ops.h" | |||
| #include "common/trans.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "utils/config_manager.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| @@ -124,7 +125,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->enable_pynative_infer()) { | |||
| tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); | |||
| } else if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(true))) { | |||
| MS_LOG(INFO) << "output sync device to host error!!!"; | |||
| tensor->set_dirty(false); | |||
| @@ -369,7 +371,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, | |||
| kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()}); | |||
| } | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); | |||
| // construct abstract of parameter | |||
| auto abstract = std::make_shared<abstract::AbstractTensor>(input_tensor); | |||
| param->set_abstract(abstract); | |||
| return param; | |||
| @@ -548,7 +550,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| if (need_sync) { | |||
| tensor->set_device_address(device_address); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
| } | |||
| @@ -620,8 +623,8 @@ void SessionBasic::Summary(KernelGraph *graph) { | |||
| (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); | |||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(true))) { | |||
| if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()), | |||
| tensor->data_type(), tensor->data_c(true))) { | |||
| MS_LOG(ERROR) << "Failed to sync output from device to host."; | |||
| } | |||
| tensor->set_dirty(false); | |||
| @@ -197,8 +197,8 @@ const std::set<std::string> kOptOperatorSet = { | |||
| kApplyRMSPropOpName, | |||
| }; | |||
| const std::set<std::string> kSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, | |||
| kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0}; | |||
| const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, | |||
| kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0}; | |||
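The rename from kSpecialFormatSet to kNeedTransFormatSet states the intent: data in these formats is laid out device-specifically and cannot be copied to or from plain NCHW host memory without a format transformation. The consumers of the set are outside this diff, so the check below is a hypothetical illustration of how it would be consulted:

```cpp
// hypothetical helper (not part of this diff): gate a host-side format
// transformation on membership in kNeedTransFormatSet
bool NeedHostTrans(const std::string &format) {
  return kNeedTransFormatSet.find(format) != kNeedTransFormatSet.end();
}
```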
| static inline void ChangeFileMode(const std::string& file_name, mode_t mode) { | |||
| if (access(file_name.c_str(), F_OK) != 0) { | |||
| @@ -80,6 +80,8 @@ TEST_F(TestHWLayerNormBetaGammaBackpropFusion, layernorm_beta_gamma_backprop_fus | |||
| builder1.SetOutputsDeviceType({kNumberTypeFloat32}); | |||
| cast0->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| cast1->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| cast0->set_abstract(x_abstract); | |||
| cast1->set_abstract(x_abstract); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast0.get()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast1.get()); | |||
| @@ -211,8 +211,8 @@ TEST_F(AnfRuntimeAlgorithmTest, EraseNodeAttr) { | |||
| TEST_F(AnfRuntimeAlgorithmTest, GetInputTensorNum) { | |||
| auto kernel_graph = std::make_shared<KernelGraph>(); | |||
| // test cnode node | |||
| auto parameter_one = kernel_graph->add_parameter(); | |||
| auto parameter_two = kernel_graph->add_parameter(); | |||
| auto parameter_one = kernel_graph->NewParameter(); | |||
| auto parameter_two = kernel_graph->NewParameter(); | |||
| std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimTensorAdd), parameter_one, parameter_two}; | |||
| auto add = kernel_graph->NewCNode(add_inputs); | |||
| EXPECT_EQ(AnfAlgo::GetInputTensorNum(add), 2); | |||
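The switch from add_parameter() to NewParameter() is forced by the NewCNode change above: NewCNode now calls AnfAlgo::IsFeatureMapOutput on every real input, which dereferences the input's kernel_info, and a bare FuncGraph parameter has none, so MS_EXCEPTION_IF_NULL would fire. A hedged contrast:

```cpp
// illustrative contrast between the two ways of creating a test parameter
auto p_old = kernel_graph->add_parameter();  // no kernel_info attached
auto p_new = kernel_graph->NewParameter();   // kernel_info + feature-map flag
// NewCNode({prim, p_old, ...}) would throw inside IsFeatureMapOutput,
// while NewCNode({prim, p_new, ...}) classifies its output correctly
```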
| @@ -247,9 +247,11 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputTensorNum) { | |||
| TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) { | |||
| auto kernel_graph = std::make_shared<KernelGraph>(); | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim::kPrimTensorAdd)); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(), | |||
| kernel_graph->NewParameter()}; | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| std::vector<size_t> shape = {1, 2, 3, 4}; | |||
| AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get()); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| @@ -266,8 +268,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) { | |||
| TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) { | |||
| auto kernel_graph = std::make_shared<KernelGraph>(); | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim::kPrimTensorAdd)); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(), | |||
| kernel_graph->NewParameter()}; | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| @@ -345,7 +347,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputInferShape) { | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| // test parameter node as input | |||
| auto parameter_node = kernel_graph->add_parameter(); | |||
| auto parameter_node = kernel_graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(parameter_node); | |||
| parameter_node->set_abstract(x_abstract); | |||
| EXPECT_THROW(AnfAlgo::GetPrevNodeOutputInferShape(parameter_node, 0), std::runtime_error); | |||
| @@ -387,13 +389,13 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) { | |||
| auto kernel_graph = std::make_shared<KernelGraph>(); | |||
| std::vector<int> shp{2, 32, 224, 224}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| auto parameter_one = kernel_graph->add_parameter(); | |||
| auto parameter_one = kernel_graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(parameter_one); | |||
| parameter_one->set_abstract(x_abstract); | |||
| auto parameter_two = kernel_graph->add_parameter(); | |||
| auto parameter_two = kernel_graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(parameter_two); | |||
| parameter_two->set_abstract(x_abstract); | |||
| auto parameter_third = kernel_graph->add_parameter(); | |||
| auto parameter_third = kernel_graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(parameter_third); | |||
| parameter_third->set_abstract(x_abstract); | |||
| // test cnode as input | |||
| @@ -466,8 +468,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) { | |||
| TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) { | |||
| auto kernel_graph = std::make_shared<KernelGraph>(); | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim::kPrimTensorAdd)); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(), | |||
| kernel_graph->NewParameter()}; | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| @@ -140,11 +140,11 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) { | |||
| std::vector<int> shape = {2, 32, 224, 224}; | |||
| auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape); | |||
| auto x_parameter = kernel_graph->add_parameter(); | |||
| auto x_parameter = kernel_graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(x_parameter); | |||
| x_parameter->set_name("x_parameter"); | |||
| x_parameter->set_abstract(abstract); | |||
| auto y_parameter = kernel_graph->add_parameter(); | |||
| auto y_parameter = kernel_graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(y_parameter); | |||
| y_parameter->set_name("y_parameter"); | |||
| y_parameter->set_abstract(abstract); | |||
| @@ -153,7 +153,7 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) { | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_abstract(abstract); | |||
| auto z_parameter = kernel_graph->add_parameter(); | |||
| auto z_parameter = kernel_graph->NewParameter(); | |||
| MS_EXCEPTION_IF_NULL(z_parameter); | |||
| z_parameter->set_name("z_parameter"); | |||
| z_parameter->set_abstract(abstract); | |||