| @@ -62,7 +62,7 @@ class TransposeGpuFwdKernel : public GpuKernel { | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output."; | MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||||
| shape_size_ = input_shape.size(); | shape_size_ = input_shape.size(); | ||||
| if (shape_size_ > TRANSPOSE_MAX_DIMENSION) { | if (shape_size_ > TRANSPOSE_MAX_DIMENSION) { | ||||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION | MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION | ||||
| @@ -52,6 +52,8 @@ TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const { | |||||
| return outputs_device_type_[output_index]; | return outputs_device_type_[output_index]; | ||||
| } | } | ||||
| const std::string &KernelBuildInfo::GetOriginDataFormat() const { return origin_data_format_; } | |||||
| const std::vector<std::string> &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; } | const std::vector<std::string> &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; } | ||||
| const std::vector<std::string> &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; } | const std::vector<std::string> &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; } | ||||
| @@ -132,6 +134,11 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &ke | |||||
| kernel_build_info_->kernel_type_ = kernel_type; | kernel_build_info_->kernel_type_ = kernel_type; | ||||
| } | } | ||||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||||
| kernel_build_info_->origin_data_format_ = origin_data_format; | |||||
| } | |||||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector<std::string> &inputs_format) { | void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector<std::string> &inputs_format) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | MS_EXCEPTION_IF_NULL(kernel_build_info_); | ||||
| kernel_build_info_->inputs_format_ = inputs_format; | kernel_build_info_->inputs_format_ = inputs_format; | ||||
| @@ -38,6 +38,7 @@ class KernelBuildInfo { | |||||
| op_pattern_ = kCommonPattern; | op_pattern_ = kCommonPattern; | ||||
| input_reshape_type_ = {}; | input_reshape_type_ = {}; | ||||
| output_reshape_type_ = {}; | output_reshape_type_ = {}; | ||||
| origin_data_format_ = kOpFormat_DEFAULT; | |||||
| inputs_format_ = {}; | inputs_format_ = {}; | ||||
| outputs_format_ = {}; | outputs_format_ = {}; | ||||
| inputs_device_type_ = {}; | inputs_device_type_ = {}; | ||||
| @@ -64,6 +65,8 @@ class KernelBuildInfo { | |||||
| std::vector<Axis> GetOutputReshapeType(size_t input_index) const; | std::vector<Axis> GetOutputReshapeType(size_t input_index) const; | ||||
| const std::string &GetOriginDataFormat() const; | |||||
| const std::vector<std::string> &GetAllInputFormats() const; | const std::vector<std::string> &GetAllInputFormats() const; | ||||
| const std::vector<std::string> &GetAllOutputFormats() const; | const std::vector<std::string> &GetAllOutputFormats() const; | ||||
| @@ -97,6 +100,7 @@ class KernelBuildInfo { | |||||
| private: | private: | ||||
| KernelType kernel_type_; | KernelType kernel_type_; | ||||
| std::string origin_data_format_; | |||||
| std::vector<std::string> inputs_format_; | std::vector<std::string> inputs_format_; | ||||
| OpPattern op_pattern_; | OpPattern op_pattern_; | ||||
| std::vector<std::string> outputs_format_; | std::vector<std::string> outputs_format_; | ||||
| @@ -135,6 +139,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder { | |||||
| void SetKernelType(const KernelType &kernel_type); | void SetKernelType(const KernelType &kernel_type); | ||||
| void SetOriginDataFormat(const std::string &origin_data_format); | |||||
| void SetInputsFormat(const std::vector<std::string> &inputs_format); | void SetInputsFormat(const std::vector<std::string> &inputs_format); | ||||
| void SetOutputsFormat(const std::vector<std::string> &outputs_format); | void SetOutputsFormat(const std::vector<std::string> &outputs_format); | ||||
| @@ -506,6 +506,45 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con | |||||
| return output_node_list; | return output_node_list; | ||||
| } | } | ||||
| std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, | |||||
| const AnfNodePtr &node, | |||||
| size_t output_index) { | |||||
| auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>(); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto iter = manager->node_users().find(node); | |||||
| if (iter == manager->node_users().end()) { | |||||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||||
| } | |||||
| auto output_info_list = iter->second; | |||||
| for (const auto &output_info : output_info_list) { | |||||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { | |||||
| continue; | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | |||||
| output_info.second == kDependAttachNodeIndex) { | |||||
| continue; | |||||
| } | |||||
| size_t used_output_index; | |||||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) { | |||||
| used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first)); | |||||
| } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) { | |||||
| used_output_index = output_index; | |||||
| } else { | |||||
| auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, output_info.second - 1); | |||||
| if (kernel_with_index.first.get() != node.get()) { | |||||
| MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]"; | |||||
| } | |||||
| used_output_index = kernel_with_index.second; | |||||
| } | |||||
| if (used_output_index == output_index) { | |||||
| output_node_list->push_back(output_info); | |||||
| } | |||||
| } | |||||
| return output_node_list; | |||||
| } | |||||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { | bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| @@ -172,6 +172,10 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | |||||
| std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph, | std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph, | ||||
| const AnfNodePtr &node); | const AnfNodePtr &node); | ||||
| std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, | |||||
| const AnfNodePtr &node, | |||||
| size_t output_index); | |||||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | ||||
| bool AnfEqual(const BaseRef &a, const BaseRef &b); | bool AnfEqual(const BaseRef &a, const BaseRef &b); | ||||
| @@ -0,0 +1,151 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/gpu/insert_format_transform_op.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/utils.h" | |||||
| #include "backend/optimizer/common/helper.h" | |||||
| #include "runtime/device/gpu/kernel_info_setter.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| std::vector<int> TransposeAxis(const std::string &src_format, const std::string &dst_format) { | |||||
| if ((src_format == kOpFormat_NCHW) && (dst_format == kOpFormat_NHWC)) { | |||||
| return {0, 2, 3, 1}; | |||||
| } else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) { | |||||
| return {0, 3, 1, 2}; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Invaild format transform, from " << src_format << " to " << dst_format; | |||||
| } | |||||
| } | |||||
| void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format, | |||||
| const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); | |||||
| auto output_type = AnfAlgo::GetOutputInferDataType(node, 0); | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||||
| builder.SetInputsFormat({input_format}); | |||||
| builder.SetInputsDeviceType({input_type}); | |||||
| builder.SetOutputsFormat({output_format}); | |||||
| builder.SetOutputsDeviceType({output_type}); | |||||
| builder.SetKernelType(UNKNOWN_KERNEL_TYPE); | |||||
| builder.SetProcessor(kernel::Processor::CUDA); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get()); | |||||
| } | |||||
| // Insert transpose op between node and used_node whose position is used_node_index. | |||||
| CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node, | |||||
| int used_node_index, const std::vector<int> &transpose_perm) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| // 1.Create a transpose node. | |||||
| auto transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name()); | |||||
| MS_EXCEPTION_IF_NULL(transpose_prim); | |||||
| // 2.Set the input of transpose. | |||||
| std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node}; | |||||
| auto transpose_op = graph->NewCNode(transpose_input); | |||||
| // 3.Set the output info of transpose. | |||||
| auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)}; | |||||
| auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)}; | |||||
| AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get()); | |||||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op); | |||||
| // 4.Set the input of used_node. | |||||
| MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope() | |||||
| << ", index: " << used_node_index; | |||||
| AnfAlgo::SetNodeInput(utils::cast<CNodePtr>(used_node), transpose_op, used_node_index); | |||||
| // 5. Update the manager info of transpose op. | |||||
| FuncGraphManagerPtr manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| manager->Clear(); | |||||
| manager->AddFuncGraph(graph); | |||||
| return transpose_op; | |||||
| } | |||||
| } // namespace | |||||
| const AnfNodePtr InsertFormatTransformOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| if (!AnfAlgo::IsRealCNodeKernel(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto iter = device::gpu::kKernelFormatPositionMap.find(AnfAlgo::GetCNodeName(node)); | |||||
| if (iter == device::gpu::kKernelFormatPositionMap.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto origin_data_format = AnfAlgo::GetOriginDataFormat(node); | |||||
| if (origin_data_format == kOpFormat_DEFAULT) { | |||||
| origin_data_format = kOpFormat_NCHW; | |||||
| } | |||||
| MS_LOG(DEBUG) << "Process node: " << node->fullname_with_scope(); | |||||
| // Insert input transpose from origin_data_format to input_format. | |||||
| auto inputs_format = AnfAlgo::GetAllInputFormats(node); | |||||
| for (size_t i = 0; i < inputs_format.size(); i++) { | |||||
| if ((inputs_format[i] != kOpFormat_DEFAULT) && (inputs_format[i] != origin_data_format)) { | |||||
| auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| auto transpose_perm = TransposeAxis(origin_data_format, inputs_format[i]); | |||||
| auto transpose_op = InsertTransposeOp(graph, input_node, node, i, transpose_perm); | |||||
| SetTransposeOpBuildInfo(kOpFormat_DEFAULT, inputs_format[i], transpose_op); | |||||
| } | |||||
| } | |||||
| // Insert output transpose from output_format to origin_data_format. | |||||
| auto outputs_format = AnfAlgo::GetAllOutputFormats(node); | |||||
| for (size_t i = 0; i < outputs_format.size(); i++) { | |||||
| if ((outputs_format[i] != kOpFormat_DEFAULT) && (outputs_format[i] != origin_data_format)) { | |||||
| // Find all nodes connected with node output, and change their inputs to transpose. | |||||
| auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i); | |||||
| for (size_t j = 0; j < used_node_list->size(); j++) { | |||||
| auto used_node = used_node_list->at(j).first; | |||||
| auto used_node_index = used_node_list->at(j).second - 1; | |||||
| auto transpose_perm = TransposeAxis(outputs_format[i], origin_data_format); | |||||
| if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) { | |||||
| MS_LOG(DEBUG) << "The used node of [" << node->fullname_with_scope() << "] is tuple item."; | |||||
| // The tuple item need get next used nodes again. | |||||
| ProcessForTupleItem(graph, used_node, used_node_index, transpose_perm, outputs_format[i]); | |||||
| continue; | |||||
| } | |||||
| auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm); | |||||
| SetTransposeOpBuildInfo(outputs_format[i], kOpFormat_DEFAULT, transpose_op); | |||||
| } | |||||
| } | |||||
| } | |||||
| return node; | |||||
| } | |||||
| void InsertFormatTransformOp::ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index, | |||||
| const std::vector<int> &transpose_perm, | |||||
| const std::string &transpose_format) const { | |||||
| auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index); | |||||
| for (size_t i = 0; i < used_node_list->size(); i++) { | |||||
| auto used_node = used_node_list->at(i).first; | |||||
| auto used_node_index = used_node_list->at(i).second - 1; | |||||
| if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) { | |||||
| MS_LOG(EXCEPTION) << "The used node of tuple item can't be tuple item."; | |||||
| } | |||||
| auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm); | |||||
| SetTransposeOpBuildInfo(transpose_format, kOpFormat_DEFAULT, transpose_op); | |||||
| } | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class InsertFormatTransformOp : public PatternProcessPass { | |||||
| public: | |||||
| explicit InsertFormatTransformOp(bool multigraph = true) | |||||
| : PatternProcessPass("insert_format_transform_op", multigraph) {} | |||||
| ~InsertFormatTransformOp() override = default; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| void ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index, | |||||
| const std::vector<int> &transpose_perm, const std::string &transpose_format) const; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_ | |||||
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/gpu/remove_format_transform_pair.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/utils.h" | |||||
| #include "backend/optimizer/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| const BaseRef RemoveFormatTransformPair::DefinePattern() const { | |||||
| VarPtr X = std::make_shared<Var>(); | |||||
| MS_EXCEPTION_IF_NULL(X); | |||||
| VectorRef transpose1 = VectorRef({prim::kPrimTranspose, X}); | |||||
| VectorRef transpose2 = VectorRef({prim::kPrimTranspose, transpose1}); | |||||
| return transpose2; | |||||
| } | |||||
| const AnfNodePtr RemoveFormatTransformPair::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| MS_LOG(DEBUG) << "Process node:" << node->fullname_with_scope(); | |||||
| auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| if (AnfAlgo::GetCNodeName(node) != prim::kPrimTranspose->name() || | |||||
| AnfAlgo::GetCNodeName(input_node) != prim::kPrimTranspose->name()) { | |||||
| MS_LOG(EXCEPTION) << "The pattern is not transpose pair, " | |||||
| << "node:" << AnfAlgo::GetCNodeName(node) << " node input:" << AnfAlgo::GetCNodeName(input_node); | |||||
| } | |||||
| // If transpose operator used by more than one other operators, it cant not be deleted directly. | |||||
| if (IsUsedByOthers(graph, input_node)) { | |||||
| MS_LOG(DEBUG) << "The transpose node [" << input_node->fullname_with_scope() | |||||
| << "] is used by more than one other operators."; | |||||
| return nullptr; | |||||
| } | |||||
| auto transpose1_input_shape = AnfAlgo::GetInputDeviceShape(input_node, 0); | |||||
| auto transpose2_output_shape = AnfAlgo::GetOutputDeviceShape(node, 0); | |||||
| if (transpose2_output_shape == transpose1_input_shape) { | |||||
| auto transpose1_input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(input_node), 0); | |||||
| MS_EXCEPTION_IF_NULL(transpose1_input_node); | |||||
| return transpose1_input_node; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_ | |||||
| #include <memory> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class RemoveFormatTransformPair : public PatternProcessPass { | |||||
| public: | |||||
| explicit RemoveFormatTransformPair(bool multigraph = true) | |||||
| : PatternProcessPass("remove_format_transform_pair", multigraph) {} | |||||
| ~RemoveFormatTransformPair() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_ | |||||
| @@ -353,6 +353,48 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { | |||||
| } | } | ||||
| } | } | ||||
| std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!AnfAlgo::IsRealKernel(node)) { | |||||
| MS_LOG(EXCEPTION) << "Not real kernel:" | |||||
| << "#node [" << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| auto build_info = kernel_info->select_kernel_build_info(); | |||||
| MS_EXCEPTION_IF_NULL(build_info); | |||||
| auto format = build_info->GetAllOutputFormats(); | |||||
| return format; | |||||
| } | |||||
| std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!AnfAlgo::IsRealKernel(node)) { | |||||
| MS_LOG(EXCEPTION) << "Not real kernel:" | |||||
| << "#node [" << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| auto build_info = kernel_info->select_kernel_build_info(); | |||||
| MS_EXCEPTION_IF_NULL(build_info); | |||||
| auto format = build_info->GetAllInputFormats(); | |||||
| return format; | |||||
| } | |||||
| std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!AnfAlgo::IsRealKernel(node)) { | |||||
| MS_LOG(EXCEPTION) << "Not real kernel:" | |||||
| << "#node [" << node->DebugString() << "]"; | |||||
| } | |||||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| auto build_info = kernel_info->select_kernel_build_info(); | |||||
| MS_EXCEPTION_IF_NULL(build_info); | |||||
| auto format = build_info->GetOriginDataFormat(); | |||||
| return format; | |||||
| } | |||||
| std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { | std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (output_idx > GetOutputTensorNum(node)) { | if (output_idx > GetOutputTensorNum(node)) { | ||||
| @@ -829,7 +871,7 @@ void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode * | |||||
| bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) { | bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| // parameter and value node is not a real kernel too | |||||
| // parameter and value node is a real kernel too | |||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -101,6 +101,12 @@ class AnfRuntimeAlgorithm { | |||||
| static size_t GetInputTensorNum(const AnfNodePtr &node); | static size_t GetInputTensorNum(const AnfNodePtr &node); | ||||
| // get the num of output real_kernel(which can be build and run in device) | // get the num of output real_kernel(which can be build and run in device) | ||||
| static size_t GetOutputTensorNum(const AnfNodePtr &node); | static size_t GetOutputTensorNum(const AnfNodePtr &node); | ||||
| // get all outputs format select of anf node | |||||
| static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node); | |||||
| // get all inputs format select of anf node | |||||
| static std::vector<std::string> GetAllInputFormats(const AnfNodePtr &node); | |||||
| // get origin data format select of anf node | |||||
| static std::string GetOriginDataFormat(const AnfNodePtr &node); | |||||
| // get output format select of anf node | // get output format select of anf node | ||||
| static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx); | static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx); | ||||
| // get input format select of anf node | // get input format select of anf node | ||||
| @@ -30,6 +30,8 @@ | |||||
| #include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h" | #include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h" | ||||
| #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | ||||
| #include "backend/optimizer/gpu/replace_addn_fusion.h" | #include "backend/optimizer/gpu/replace_addn_fusion.h" | ||||
| #include "backend/optimizer/gpu/insert_format_transform_op.h" | |||||
| #include "backend/optimizer/gpu/remove_format_transform_pair.h" | |||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "common/trans.h" | #include "common/trans.h" | ||||
| @@ -76,6 +78,8 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||||
| void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | ||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | auto optimizer = std::make_shared<opt::GraphOptimizer>(); | ||||
| auto pm = std::make_shared<opt::PassManager>(); | auto pm = std::make_shared<opt::PassManager>(); | ||||
| pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>()); | |||||
| pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>()); | |||||
| pm->AddPass(std::make_shared<opt::AllReduceFusion>()); | pm->AddPass(std::make_shared<opt::AllReduceFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::GetitemTuple>()); | pm->AddPass(std::make_shared<opt::GetitemTuple>()); | ||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| @@ -203,7 +207,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList | |||||
| } | } | ||||
| // Assign CUDA streams | // Assign CUDA streams | ||||
| AssignStream(graph); | AssignStream(graph); | ||||
| // Hide NoOp from execution graph | |||||
| // Hide NopOp from execution graph | |||||
| opt::HideNopNode(graph.get()); | opt::HideNopNode(graph.get()); | ||||
| // Build kernel if node is cnode | // Build kernel if node is cnode | ||||
| BuildKernel(graph); | BuildKernel(graph); | ||||
| @@ -213,7 +217,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList | |||||
| graph->set_execution_order(execution_order); | graph->set_execution_order(execution_order); | ||||
| // Get summary nodes. | // Get summary nodes. | ||||
| SetSummaryNodes(graph.get()); | SetSummaryNodes(graph.get()); | ||||
| // Remove NoOp from execution graph | |||||
| // Remove NopOp from execution graph | |||||
| opt::RemoveNopNode(graph.get()); | opt::RemoveNopNode(graph.get()); | ||||
| // Set graph manager. | // Set graph manager. | ||||
| MS_EXCEPTION_IF_NULL(context_); | MS_EXCEPTION_IF_NULL(context_); | ||||
| @@ -272,7 +276,7 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| SelectKernel(kernel_graph); | SelectKernel(kernel_graph); | ||||
| StartKernelRT(); | StartKernelRT(); | ||||
| // Hide NoOp from execution graph | |||||
| // Hide NopOp from execution graph | |||||
| opt::HideNopNode(kernel_graph.get()); | opt::HideNopNode(kernel_graph.get()); | ||||
| BuildKernel(kernel_graph); | BuildKernel(kernel_graph); | ||||
| run_op_graphs_[graph_info] = kernel_graph; | run_op_graphs_[graph_info] = kernel_graph; | ||||
| @@ -282,7 +286,7 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph | |||||
| const std::vector<tensor::TensorPtr> &input_tensors) { | const std::vector<tensor::TensorPtr> &input_tensors) { | ||||
| auto kernel_graph = run_op_graphs_[graph_info]; | auto kernel_graph = run_op_graphs_[graph_info]; | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| // Remove NoOp from execution graph | |||||
| // Remove NopOp from execution graph | |||||
| opt::RemoveNopNode(kernel_graph.get()); | opt::RemoveNopNode(kernel_graph.get()); | ||||
| RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get()); | RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get()); | ||||
| // Execute the computation | // Execute the computation | ||||
| @@ -252,7 +252,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it."); | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it."); | ||||
| } else { | } else { | ||||
| if (!stop_send_) { | if (!stop_send_) { | ||||
| MS_LOG(WARNING) << "Retry pushing data..."; | |||||
| MS_LOG(DEBUG) << "Retry pushing data..."; | |||||
| continue; | continue; | ||||
| } | } | ||||
| break; | break; | ||||
| @@ -96,10 +96,10 @@ BlockQueueStatus_T BlockingQueue::Create(void *addr, const std::vector<size_t> & | |||||
| void BlockingQueue::RegisterRelease(const std::function<void(void *)> &func) { queue_->RegisterRelease(func); } | void BlockingQueue::RegisterRelease(const std::function<void(void *)> &func) { queue_->RegisterRelease(func); } | ||||
| BlockQueueStatus_T BlockingQueue::Push(const std::vector<DataItemGpu> &data, unsigned int timeout_in_sec) { | |||||
| BlockQueueStatus_T BlockingQueue::Push(const std::vector<DataItemGpu> &data, unsigned int) { | |||||
| std::unique_lock<std::mutex> locker(mutex_); | std::unique_lock<std::mutex> locker(mutex_); | ||||
| if (queue_->IsFull()) { | if (queue_->IsFull()) { | ||||
| if (not_full_cond_.wait_for(locker, std::chrono::seconds(timeout_in_sec)) == std::cv_status::timeout) { | |||||
| if (not_full_cond_.wait_for(locker, std::chrono::microseconds(100)) == std::cv_status::timeout) { | |||||
| return TIMEOUT; | return TIMEOUT; | ||||
| } | } | ||||
| } | } | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "backend/kernel_compiler/kernel.h" | #include "backend/kernel_compiler/kernel.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "utils/ms_context.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | ||||
| #include "backend/kernel_compiler/kernel_build_info.h" | #include "backend/kernel_compiler/kernel_build_info.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| @@ -157,25 +158,87 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) { | |||||
| auto ms_context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | |||||
| if (ms_context->execution_mode() == kPynativeMode) { | |||||
| return false; | |||||
| } | |||||
| if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) { | |||||
| return false; | |||||
| } | |||||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| auto iter = kKernelFormatPositionMap.find(kernel_name); | |||||
| if (iter == kKernelFormatPositionMap.end()) { | |||||
| return false; | |||||
| } | |||||
| if (inputs_type.size() == 0) { | |||||
| return false; | |||||
| } | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| if (input_shape.size() != 4) { | |||||
| return false; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type, | |||||
| std::vector<std::string> *inputs_format, std::vector<std::string> *outputs_format, | |||||
| std::string *origin_data_format) { | |||||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| auto iter = kKernelFormatPositionMap.find(kernel_name); | |||||
| if (iter == kKernelFormatPositionMap.end()) { | |||||
| return; | |||||
| } | |||||
| auto cal_format = (inputs_type[0] == kNumberTypeFloat16) ? kOpFormat_NHWC : kOpFormat_NCHW; | |||||
| MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format; | |||||
| auto inputs_format_position = iter->second.first; | |||||
| for (const auto &input_format_position : inputs_format_position) { | |||||
| if (input_format_position >= inputs_format->size()) { | |||||
| MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size [" | |||||
| << inputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]"; | |||||
| } | |||||
| (*inputs_format)[input_format_position] = cal_format; | |||||
| } | |||||
| auto outputs_format_position = iter->second.second; | |||||
| for (const auto &output_format_position : outputs_format_position) { | |||||
| if (output_format_position >= outputs_format->size()) { | |||||
| MS_LOG(EXCEPTION) << "The position [" << output_format_position << "] is out of range of the output size [" | |||||
| << outputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]"; | |||||
| } | |||||
| (*outputs_format)[output_format_position] = cal_format; | |||||
| } | |||||
| auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (prim->HasAttr("data_format")) { | |||||
| *origin_data_format = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "data_format"); | |||||
| } | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| void SetKernelInfo(const CNodePtr &kernel_node) { | void SetKernelInfo(const CNodePtr &kernel_node) { | ||||
| std::vector<std::string> inputs_format; | std::vector<std::string> inputs_format; | ||||
| std::vector<TypeId> inputs_type; | std::vector<TypeId> inputs_type; | ||||
| std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | |||||
| std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | ||||
| inputs_format.emplace_back(kOpFormat_DEFAULT); | inputs_format.emplace_back(kOpFormat_DEFAULT); | ||||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); | inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); | ||||
| } | } | ||||
| builder->SetInputsFormat(inputs_format); | |||||
| builder->SetInputsDeviceType(inputs_type); | |||||
| std::vector<std::string> outputs_format; | std::vector<std::string> outputs_format; | ||||
| std::vector<TypeId> outputs_type; | std::vector<TypeId> outputs_type; | ||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | ||||
| outputs_format.emplace_back(kOpFormat_DEFAULT); | outputs_format.emplace_back(kOpFormat_DEFAULT); | ||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | ||||
| } | } | ||||
| std::string origin_data_format = kOpFormat_DEFAULT; | |||||
| if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) { | |||||
| UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); | |||||
| } | |||||
| std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | |||||
| std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| builder->SetOriginDataFormat(origin_data_format); | |||||
| builder->SetInputsFormat(inputs_format); | |||||
| builder->SetInputsDeviceType(inputs_type); | |||||
| builder->SetOutputsFormat(outputs_format); | builder->SetOutputsFormat(outputs_format); | ||||
| builder->SetOutputsDeviceType(outputs_type); | builder->SetOutputsDeviceType(outputs_type); | ||||
| @@ -20,13 +20,35 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/dtype.h" | #include "ir/dtype.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "frontend/operator/ops.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace gpu { | namespace gpu { | ||||
| // map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the insert position of format transform. | |||||
| static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = { | |||||
| {prim::kPrimConv2D->name(), {{0, 1}, {0}}}, | |||||
| {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}}, | |||||
| {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}}, | |||||
| {prim::kPrimRelu->name(), {{0}, {0}}}, | |||||
| {prim::kPrimReluGrad->name(), {{0, 1}, {0}}}, | |||||
| {prim::kPrimMaxPool->name(), {{0}, {0}}}, | |||||
| {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}}, | |||||
| {kAvgPoolOpName, {{0}, {0}}}, | |||||
| {kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}}, | |||||
| {kTensorAddOpName, {{0, 1}, {0}}}, | |||||
| {kFusedBatchNormEx, {{0}, {0}}}, | |||||
| {kFusedBatchNormExWithActivation, {{0}, {0}}}, | |||||
| {kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}}, | |||||
| {kFusedBatchNormGradEx, {{0, 1}, {0}}}, | |||||
| {kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}}, | |||||
| {kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}}, | |||||
| }; | |||||
| void SetKernelInfo(const CNodePtr &apply_kernel_ptr); | void SetKernelInfo(const CNodePtr &apply_kernel_ptr); | ||||
| class KernelAttr { | class KernelAttr { | ||||
| @@ -189,6 +189,9 @@ constexpr auto kPullOpName = "Pull"; | |||||
| constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup"; | constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup"; | ||||
| constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; | constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; | ||||
| constexpr auto kPaddingOpName = "Padding"; | constexpr auto kPaddingOpName = "Padding"; | ||||
| constexpr auto kAvgPoolOpName = "AvgPool"; | |||||
| constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; | |||||
| constexpr auto kTensorAddOpName = "TensorAdd"; | |||||
| // attr key name | // attr key name | ||||
| constexpr auto kAttrInputNames = "input_names"; | constexpr auto kAttrInputNames = "input_names"; | ||||