@@ -62,7 +62,7 @@ class TransposeGpuFwdKernel : public GpuKernel {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output.";
      return false;
    }
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    shape_size_ = input_shape.size();
    if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
      MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION
@@ -52,6 +52,8 @@ TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
  return outputs_device_type_[output_index];
}

const std::string &KernelBuildInfo::GetOriginDataFormat() const { return origin_data_format_; }

const std::vector<std::string> &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; }

const std::vector<std::string> &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; }
@@ -132,6 +134,11 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &ke
  kernel_build_info_->kernel_type_ = kernel_type;
}

void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) {
  MS_EXCEPTION_IF_NULL(kernel_build_info_);
  kernel_build_info_->origin_data_format_ = origin_data_format;
}

void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector<std::string> &inputs_format) {
  MS_EXCEPTION_IF_NULL(kernel_build_info_);
  kernel_build_info_->inputs_format_ = inputs_format;
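// Illustrative sketch (not part of this change): how the new origin-format field
// is meant to be used together with the per-tensor formats. Values are examples.
//   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
//   builder.SetOriginDataFormat(kOpFormat_NCHW);  // layout the model was written in
//   builder.SetInputsFormat({kOpFormat_NHWC});    // layout selected for the device kernel
//   builder.SetOutputsFormat({kOpFormat_NHWC});
//   auto info = builder.Build();
//   // A pass can then compare info->GetOriginDataFormat() with
//   // info->GetAllInputFormats() to decide whether a transpose is required.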
@@ -38,6 +38,7 @@ class KernelBuildInfo {
    op_pattern_ = kCommonPattern;
    input_reshape_type_ = {};
    output_reshape_type_ = {};
    origin_data_format_ = kOpFormat_DEFAULT;
    inputs_format_ = {};
    outputs_format_ = {};
    inputs_device_type_ = {};
@@ -64,6 +65,8 @@ class KernelBuildInfo {
  std::vector<Axis> GetOutputReshapeType(size_t input_index) const;
  const std::string &GetOriginDataFormat() const;
  const std::vector<std::string> &GetAllInputFormats() const;
  const std::vector<std::string> &GetAllOutputFormats() const;
@@ -97,6 +100,7 @@ class KernelBuildInfo {
 private:
  KernelType kernel_type_;
  std::string origin_data_format_;
  std::vector<std::string> inputs_format_;
  OpPattern op_pattern_;
  std::vector<std::string> outputs_format_;
@@ -135,6 +139,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
  void SetKernelType(const KernelType &kernel_type);
  void SetOriginDataFormat(const std::string &origin_data_format);
  void SetInputsFormat(const std::vector<std::string> &inputs_format);
  void SetOutputsFormat(const std::vector<std::string> &outputs_format);
@@ -506,6 +506,45 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
  return output_node_list;
}

std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
                                                                                        const AnfNodePtr &node,
                                                                                        size_t output_index) {
  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    MS_LOG(EXCEPTION) << "node has no output in manager";
  }
  auto output_info_list = iter->second;
  for (const auto &output_info : output_info_list) {
    if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
      continue;
    }
    if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
        output_info.second == kDependAttachNodeIndex) {
      continue;
    }
    size_t used_output_index;
    if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) {
      used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
    } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
      used_output_index = output_index;
    } else {
      auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, output_info.second - 1);
      if (kernel_with_index.first.get() != node.get()) {
        MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
      }
      used_output_index = kernel_with_index.second;
    }
    if (used_output_index == output_index) {
      output_node_list->push_back(output_info);
    }
  }
  return output_node_list;
}
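
// Illustrative example (not from the diff): if conv = Conv2D(x, w) feeds both
// relu = ReLU(conv) and cast = Cast(conv), then
// GetRealNodeUsedListByOutputIdx(graph, conv, 0) returns {(relu, 1), (cast, 1)},
// where each pair is (user node, 1-based input position within that user).
// Depend/ControlDepend edges are skipped, and TupleGetItem users are matched
// through their getitem index instead of their input edge.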
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
@@ -172,6 +172,10 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
                                                                             const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
                                                                                        const AnfNodePtr &node,
                                                                                        size_t output_index);
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
bool AnfEqual(const BaseRef &a, const BaseRef &b);
@@ -0,0 +1,151 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/gpu/insert_format_transform_op.h"

#include <memory>
#include <vector>
#include <string>

#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"

namespace mindspore {
namespace opt {
namespace {
std::vector<int> TransposeAxis(const std::string &src_format, const std::string &dst_format) {
  if ((src_format == kOpFormat_NCHW) && (dst_format == kOpFormat_NHWC)) {
    return {0, 2, 3, 1};
  } else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) {
    return {0, 3, 1, 2};
  } else {
    MS_LOG(EXCEPTION) << "Invalid format transform, from " << src_format << " to " << dst_format;
  }
}
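
// Illustrative (not from the diff): the returned vector is a permutation of
// source axes, i.e. dst_shape[i] = src_shape[perm[i]]. Applying {0, 2, 3, 1}
// to an NCHW shape {N, C, H, W} yields {N, H, W, C}, and {0, 3, 1, 2} is its
// inverse, mapping NHWC back to NCHW.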

void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format,
                             const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  auto output_type = AnfAlgo::GetOutputInferDataType(node, 0);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({input_format});
  builder.SetInputsDeviceType({input_type});
  builder.SetOutputsFormat({output_format});
  builder.SetOutputsDeviceType({output_type});
  builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
  builder.SetProcessor(kernel::Processor::CUDA);
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
}

// Insert a transpose op between node and used_node; used_node_index is the input position of used_node.
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
                           int used_node_index, const std::vector<int> &transpose_perm) {
  MS_EXCEPTION_IF_NULL(graph);
  // 1. Create a transpose node.
  auto transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
  MS_EXCEPTION_IF_NULL(transpose_prim);
  // 2. Set the input of transpose.
  std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
  auto transpose_op = graph->NewCNode(transpose_input);
  // 3. Set the output info of transpose.
  std::vector<TypeId> transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
  std::vector<std::vector<size_t>> transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
  AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
  AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
  // 4. Set the input of used_node.
  MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
                << ", index: " << used_node_index;
  AnfAlgo::SetNodeInput(utils::cast<CNodePtr>(used_node), transpose_op, used_node_index);
  // 5. Update the manager info of the transpose op.
  FuncGraphManagerPtr manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->Clear();
  manager->AddFuncGraph(graph);
  return transpose_op;
}
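
// Illustrative (not from the diff): for y = ReLU(x) where ReLU selected NHWC
// for input 0 while the graph is NCHW, the call
//   InsertTransposeOp(graph, x, relu, 0, TransposeAxis(kOpFormat_NCHW, kOpFormat_NHWC));
// rewires the graph to x -> Transpose(perm={0, 2, 3, 1}) -> ReLU, and
// SetTransposeOpBuildInfo then marks the transpose as DEFAULT in / NHWC out.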
}  // namespace

const AnfNodePtr InsertFormatTransformOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                  const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  if (!AnfAlgo::IsRealCNodeKernel(node)) {
    return nullptr;
  }
  auto iter = device::gpu::kKernelFormatPositionMap.find(AnfAlgo::GetCNodeName(node));
  if (iter == device::gpu::kKernelFormatPositionMap.end()) {
    return nullptr;
  }
  auto origin_data_format = AnfAlgo::GetOriginDataFormat(node);
  if (origin_data_format == kOpFormat_DEFAULT) {
    origin_data_format = kOpFormat_NCHW;
  }
  MS_LOG(DEBUG) << "Process node: " << node->fullname_with_scope();
  // Insert input transpose from origin_data_format to input_format.
  auto inputs_format = AnfAlgo::GetAllInputFormats(node);
  for (size_t i = 0; i < inputs_format.size(); i++) {
    if ((inputs_format[i] != kOpFormat_DEFAULT) && (inputs_format[i] != origin_data_format)) {
      auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
      MS_EXCEPTION_IF_NULL(input_node);
      auto transpose_perm = TransposeAxis(origin_data_format, inputs_format[i]);
      auto transpose_op = InsertTransposeOp(graph, input_node, node, i, transpose_perm);
      SetTransposeOpBuildInfo(kOpFormat_DEFAULT, inputs_format[i], transpose_op);
    }
  }
  // Insert output transpose from output_format to origin_data_format.
  auto outputs_format = AnfAlgo::GetAllOutputFormats(node);
  for (size_t i = 0; i < outputs_format.size(); i++) {
    if ((outputs_format[i] != kOpFormat_DEFAULT) && (outputs_format[i] != origin_data_format)) {
      // Find all nodes connected to this output of node, and redirect their inputs to the transpose.
      auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
      for (size_t j = 0; j < used_node_list->size(); j++) {
        auto used_node = used_node_list->at(j).first;
        auto used_node_index = used_node_list->at(j).second - 1;
        auto transpose_perm = TransposeAxis(outputs_format[i], origin_data_format);
        if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
          MS_LOG(DEBUG) << "The used node of [" << node->fullname_with_scope() << "] is a tuple getitem.";
          // The tuple getitem's own users have to be looked up again.
          ProcessForTupleItem(graph, used_node, used_node_index, transpose_perm, outputs_format[i]);
          continue;
        }
        auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
        SetTransposeOpBuildInfo(outputs_format[i], kOpFormat_DEFAULT, transpose_op);
      }
    }
  }
  return node;
}

void InsertFormatTransformOp::ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
                                                  const std::vector<int> &transpose_perm,
                                                  const std::string &transpose_format) const {
  auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
  for (size_t i = 0; i < used_node_list->size(); i++) {
    auto used_node = used_node_list->at(i).first;
    auto used_node_index = used_node_list->at(i).second - 1;
    if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
      MS_LOG(EXCEPTION) << "The used node of a tuple getitem can't be another tuple getitem.";
    }
    auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
    SetTransposeOpBuildInfo(transpose_format, kOpFormat_DEFAULT, transpose_op);
  }
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,39 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_

#include <memory>
#include <string>
#include <vector>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class InsertFormatTransformOp : public PatternProcessPass {
 public:
  explicit InsertFormatTransformOp(bool multigraph = true)
      : PatternProcessPass("insert_format_transform_op", multigraph) {}
  ~InsertFormatTransformOp() override = default;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  void ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
                           const std::vector<int> &transpose_perm, const std::string &transpose_format) const;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_
@@ -0,0 +1,65 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/gpu/remove_format_transform_pair.h"

#include <memory>
#include <vector>
#include <string>

#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {
const BaseRef RemoveFormatTransformPair::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(X);
  VectorRef transpose1 = VectorRef({prim::kPrimTranspose, X});
  VectorRef transpose2 = VectorRef({prim::kPrimTranspose, transpose1});
  return transpose2;
}

const AnfNodePtr RemoveFormatTransformPair::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                    const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  MS_LOG(DEBUG) << "Process node:" << node->fullname_with_scope();
  auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(input_node);
  if (AnfAlgo::GetCNodeName(node) != prim::kPrimTranspose->name() ||
      AnfAlgo::GetCNodeName(input_node) != prim::kPrimTranspose->name()) {
    MS_LOG(EXCEPTION) << "The pattern is not a transpose pair, "
                      << "node:" << AnfAlgo::GetCNodeName(node) << " node input:" << AnfAlgo::GetCNodeName(input_node);
  }
  // If the transpose operator is used by more than one other operator, it cannot be deleted directly.
  if (IsUsedByOthers(graph, input_node)) {
    MS_LOG(DEBUG) << "The transpose node [" << input_node->fullname_with_scope()
                  << "] is used by more than one other operator.";
    return nullptr;
  }
  auto transpose1_input_shape = AnfAlgo::GetInputDeviceShape(input_node, 0);
  auto transpose2_output_shape = AnfAlgo::GetOutputDeviceShape(node, 0);
  if (transpose2_output_shape == transpose1_input_shape) {
    auto transpose1_input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(input_node), 0);
    MS_EXCEPTION_IF_NULL(transpose1_input_node);
    return transpose1_input_node;
  }
  return nullptr;
}
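
// Illustrative (not from the diff): this pass folds
//   x -> Transpose(NCHW->NHWC) -> Transpose(NHWC->NCHW) -> y
// into x -> y, which is exactly the pattern left behind when
// InsertFormatTransformOp places an output transpose of one kernel directly
// in front of an input transpose of the next one.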
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,34 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class RemoveFormatTransformPair : public PatternProcessPass {
 public:
  explicit RemoveFormatTransformPair(bool multigraph = true)
      : PatternProcessPass("remove_format_transform_pair", multigraph) {}
  ~RemoveFormatTransformPair() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_
@@ -353,6 +353,48 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
  }
}

std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfAlgo::IsRealKernel(node)) {
    MS_LOG(EXCEPTION) << "Not real kernel:"
                      << "#node [" << node->DebugString() << "]";
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetAllOutputFormats();
  return format;
}

std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfAlgo::IsRealKernel(node)) {
    MS_LOG(EXCEPTION) << "Not real kernel:"
                      << "#node [" << node->DebugString() << "]";
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetAllInputFormats();
  return format;
}

std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfAlgo::IsRealKernel(node)) {
    MS_LOG(EXCEPTION) << "Not real kernel:"
                      << "#node [" << node->DebugString() << "]";
  }
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetOriginDataFormat();
  return format;
}

std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
@@ -829,7 +871,7 @@ void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *
bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // parameter and value node is not a real kernel too
  // parameter and value node is a real kernel too
  if (!node->isa<CNode>()) {
    return true;
  }
@@ -101,6 +101,12 @@ class AnfRuntimeAlgorithm {
  static size_t GetInputTensorNum(const AnfNodePtr &node);
  // get the number of outputs of the real_kernel (which can be built and run on device)
  static size_t GetOutputTensorNum(const AnfNodePtr &node);
  // get the formats selected for all outputs of the anf node
  static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node);
  // get the formats selected for all inputs of the anf node
  static std::vector<std::string> GetAllInputFormats(const AnfNodePtr &node);
  // get the origin data format selected for the anf node
  static std::string GetOriginDataFormat(const AnfNodePtr &node);
  // get the format selected for one output of the anf node
  static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
  // get the format selected for one input of the anf node
@@ -30,6 +30,8 @@
#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include "backend/optimizer/gpu/replace_addn_fusion.h"
#include "backend/optimizer/gpu/insert_format_transform_op.h"
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/ms_utils.h"
#include "common/trans.h"
@@ -76,6 +78,8 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
  pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
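  // Note (reviewer, illustrative): the order matters -- InsertFormatTransformOp runs
  // first so that RemoveFormatTransformPair can fold any back-to-back transpose
  // pairs the insertion pass produced between adjacent NHWC kernels.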
  pm->AddPass(std::make_shared<opt::AllReduceFusion>());
  pm->AddPass(std::make_shared<opt::GetitemTuple>());
  optimizer->AddPassManager(pm);
@@ -203,7 +207,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
  }
  // Assign CUDA streams
  AssignStream(graph);
  // Hide NoOp from execution graph
  // Hide NopOp from execution graph
  opt::HideNopNode(graph.get());
  // Build kernel if node is cnode
  BuildKernel(graph);
@@ -213,7 +217,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
  graph->set_execution_order(execution_order);
  // Get summary nodes.
  SetSummaryNodes(graph.get());
  // Remove NoOp from execution graph
  // Remove NopOp from execution graph
  opt::RemoveNopNode(graph.get());
  // Set graph manager.
  MS_EXCEPTION_IF_NULL(context_);
@@ -272,7 +276,7 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in
  MS_EXCEPTION_IF_NULL(kernel_graph);
  SelectKernel(kernel_graph);
  StartKernelRT();
  // Hide NoOp from execution graph
  // Hide NopOp from execution graph
  opt::HideNopNode(kernel_graph.get());
  BuildKernel(kernel_graph);
  run_op_graphs_[graph_info] = kernel_graph;
@@ -282,7 +286,7 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
                            const std::vector<tensor::TensorPtr> &input_tensors) {
  auto kernel_graph = run_op_graphs_[graph_info];
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Remove NoOp from execution graph
  // Remove NopOp from execution graph
  opt::RemoveNopNode(kernel_graph.get());
  RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get());
  // Execute the computation
@@ -252,7 +252,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it.");
    } else {
      if (!stop_send_) {
        MS_LOG(WARNING) << "Retry pushing data...";
        MS_LOG(DEBUG) << "Retry pushing data...";
        continue;
      }
      break;
@@ -96,10 +96,10 @@ BlockQueueStatus_T BlockingQueue::Create(void *addr, const std::vector<size_t> &
void BlockingQueue::RegisterRelease(const std::function<void(void *)> &func) { queue_->RegisterRelease(func); }

BlockQueueStatus_T BlockingQueue::Push(const std::vector<DataItemGpu> &data, unsigned int timeout_in_sec) {
BlockQueueStatus_T BlockingQueue::Push(const std::vector<DataItemGpu> &data, unsigned int) {
  std::unique_lock<std::mutex> locker(mutex_);
  if (queue_->IsFull()) {
    if (not_full_cond_.wait_for(locker, std::chrono::seconds(timeout_in_sec)) == std::cv_status::timeout) {
    if (not_full_cond_.wait_for(locker, std::chrono::microseconds(100)) == std::cv_status::timeout) {
      return TIMEOUT;
    }
  }
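// Reviewer note (inference, not stated in the diff): Push now ignores the caller's
// timeout and waits at most 100 microseconds before returning TIMEOUT, turning the
// blocking push into a short poll; RetryPushGPUData simply retries on TIMEOUT,
// which is presumably why its retry message was demoted from WARNING to DEBUG above.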
@@ -19,6 +19,7 @@
#include <memory>
#include "backend/kernel_compiler/kernel.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/session/anf_runtime_algorithm.h"
@@ -157,25 +158,87 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
    }
  }
}

bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->execution_mode() == kPynativeMode) {
    return false;
  }
  if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
    return false;
  }
  auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
  auto iter = kKernelFormatPositionMap.find(kernel_name);
  if (iter == kKernelFormatPositionMap.end()) {
    return false;
  }
  if (inputs_type.size() == 0) {
    return false;
  }
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  if (input_shape.size() != 4) {
    return false;
  }
  return true;
}
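
// Illustrative (not from the diff): with the predicate above, a 4-D float16
// Conv2D in graph mode passes every check, so UpdateKernelFormatInfo below
// rewrites positions {0, 1} of its inputs and {0} of its outputs to NHWC
// (the layout cuDNN generally prefers for half-precision convolutions),
// while float32 kernels keep NCHW.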

void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type,
                            std::vector<std::string> *inputs_format, std::vector<std::string> *outputs_format,
                            std::string *origin_data_format) {
  auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
  auto iter = kKernelFormatPositionMap.find(kernel_name);
  if (iter == kKernelFormatPositionMap.end()) {
    return;
  }
  auto cal_format = (inputs_type[0] == kNumberTypeFloat16) ? kOpFormat_NHWC : kOpFormat_NCHW;
  MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format;
  auto inputs_format_position = iter->second.first;
  for (const auto &input_format_position : inputs_format_position) {
    if (input_format_position >= inputs_format->size()) {
      MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size ["
                        << inputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]";
    }
    (*inputs_format)[input_format_position] = cal_format;
  }
  auto outputs_format_position = iter->second.second;
  for (const auto &output_format_position : outputs_format_position) {
    if (output_format_position >= outputs_format->size()) {
      MS_LOG(EXCEPTION) << "The position [" << output_format_position << "] is out of range of the output size ["
                        << outputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]";
    }
    (*outputs_format)[output_format_position] = cal_format;
  }
  auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->HasAttr("data_format")) {
    *origin_data_format = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "data_format");
  }
}
}  // namespace

void SetKernelInfo(const CNodePtr &kernel_node) {
  std::vector<std::string> inputs_format;
  std::vector<TypeId> inputs_type;
  std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
    std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
    inputs_format.emplace_back(kOpFormat_DEFAULT);
    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
  }
  builder->SetInputsFormat(inputs_format);
  builder->SetInputsDeviceType(inputs_type);
  std::vector<std::string> outputs_format;
  std::vector<TypeId> outputs_type;
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
    outputs_format.emplace_back(kOpFormat_DEFAULT);
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
  }
  std::string origin_data_format = kOpFormat_DEFAULT;
  if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
    UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
  }
  std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
    std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  builder->SetOriginDataFormat(origin_data_format);
  builder->SetInputsFormat(inputs_format);
  builder->SetInputsDeviceType(inputs_type);
  builder->SetOutputsFormat(outputs_format);
  builder->SetOutputsDeviceType(outputs_type);
@@ -20,13 +20,35 @@
#include <utility>
#include <string>
#include <vector>
#include <map>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "utils/utils.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace device {
namespace gpu {
// map<opName, (inputFormatPosition, outputFormatPosition)>: records which input and output
// positions of each op take the computed format, i.e. where a format transform may need to be inserted.
static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
  {prim::kPrimConv2D->name(), {{0, 1}, {0}}},
  {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}},
  {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}},
  {prim::kPrimRelu->name(), {{0}, {0}}},
  {prim::kPrimReluGrad->name(), {{0, 1}, {0}}},
  {prim::kPrimMaxPool->name(), {{0}, {0}}},
  {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}},
  {kAvgPoolOpName, {{0}, {0}}},
  {kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}},
  {kTensorAddOpName, {{0, 1}, {0}}},
  {kFusedBatchNormEx, {{0}, {0}}},
  {kFusedBatchNormExWithActivation, {{0}, {0}}},
  {kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}},
  {kFusedBatchNormGradEx, {{0, 1}, {0}}},
  {kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}},
  {kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
};
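
// Example (illustrative, not from the diff): reading an entry of the map.
//   auto it = kKernelFormatPositionMap.find(prim::kPrimConv2D->name());
//   const auto &input_positions = it->second.first;    // {0, 1}: data and weight
//   const auto &output_positions = it->second.second;  // {0}: the conv output
// Positions not listed (e.g. the 1-D scale/bias/mean/variance inputs of
// FusedBatchNormEx) keep the default format.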
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);

class KernelAttr {
@@ -189,6 +189,9 @@ constexpr auto kPullOpName = "Pull";
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
constexpr auto kTensorAddOpName = "TensorAdd";

// attr key name
constexpr auto kAttrInputNames = "input_names";