@@ -147,6 +147,39 @@ bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspor
  return true;
}

void GenNoneInputDescJson(const std::shared_ptr<OpIOInfo> &input_ptr, size_t input_i,
                          std::vector<nlohmann::json> *input_list) {
  nlohmann::json input_desc_json;
  auto in_name = input_ptr->name();
  input_desc_json[kJName] = in_name + std::to_string(input_i);
  input_desc_json[kJValid] = false;
  input_list->emplace_back(input_desc_json);
}

void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
                                                 bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
                                                 const string &op_input_name, size_t input_i,
                                                 std::vector<nlohmann::json> *input_list) {
  auto dtype = GetDeviceInputType(anf_node, real_input_index);
  auto format = GetDeviceInputFormat(anf_node, real_input_index);
  auto shape = GetDeviceInputShape(anf_node, real_input_index);
  auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
  if (ori_shape.empty()) {
    ori_shape.emplace_back(1);
  }
  nlohmann::json input_desc_json;
  input_desc_json[kJDtype] = dtype;
  input_desc_json[kJName] = op_input_name + std::to_string(input_i);
  input_desc_json[kJOriShape] = ori_shape;
  input_desc_json[kJOriFormat] = kOpFormat_NCHW;
  input_desc_json[kJShape] = shape;
  input_desc_json[kJFormat] = format;
  input_desc_json[kJValid] = value;
  input_desc_json[kJParamType] = input_ptr->param_type();
  input_desc_json[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index);
  input_list->emplace_back(input_desc_json);
}

bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
                                            bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
                                            const string &op_input_name, size_t input_i,
@@ -156,32 +189,19 @@ bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_
  MS_EXCEPTION_IF_NULL(input_list);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  if (op_name == kDynamicRNNOpName && input_ptr->name() == "seq_length") {
    nlohmann::json input_desc_json;
    auto in_name = input_ptr->name();
    input_desc_json[kJName] = in_name + std::to_string(input_i);
    input_desc_json[kJValid] = false;
    input_list->emplace_back(input_desc_json);
    GenNoneInputDescJson(input_ptr, input_i, input_list);
  } else if (op_name == kDynamicGRUV2OpName) {
    auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node, "placeholder_index");
    auto item = find(none_index.begin(), none_index.end(), input_ptr->index());
    if (item != none_index.end()) {
      GenNoneInputDescJson(input_ptr, input_i, input_list);
    } else {
      GenValidInputDescJson(anf_node, real_input_index, value, input_ptr, op_input_name, input_i, input_list);
    }
  } else if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
    TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
  } else {
    auto dtype = GetDeviceInputType(anf_node, real_input_index);
    auto format = GetDeviceInputFormat(anf_node, real_input_index);
    auto shape = GetDeviceInputShape(anf_node, real_input_index);
    auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
    if (ori_shape.empty()) {
      ori_shape.emplace_back(1);
    }
    nlohmann::json input_desc_json;
    input_desc_json[kJDtype] = dtype;
    input_desc_json[kJName] = op_input_name + std::to_string(input_i);
    input_desc_json[kJOriShape] = ori_shape;
    input_desc_json[kJOriFormat] = kOpFormat_NCHW;
    input_desc_json[kJShape] = shape;
    input_desc_json[kJFormat] = format;
    input_desc_json[kJValid] = value;
    input_desc_json[kJParamType] = input_ptr->param_type();
    input_desc_json[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index);
    input_list->emplace_back(input_desc_json);
    GenValidInputDescJson(anf_node, real_input_index, value, input_ptr, op_input_name, input_i, input_list);
  }
  return true;
}
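The refactor above splits the old inline logic into two paths: GenValidInputDescJson fills a full tensor descriptor from the device dtype/format/shape queries, while GenNoneInputDescJson emits a stub with valid = false for an optional input that was not supplied (DynamicRNN's seq_length, or any DynamicGRUV2 input listed in placeholder_index). A minimal Python sketch of the two descriptor shapes follows; the literal key strings and sample values are assumptions read off the kJ* constants and the calls above, not copied from generated JSON.

# Hedged illustration only: keys are assumed expansions of kJName, kJDtype, etc.;
# comments note where each field comes from in GenValidInputDescJson.
valid_input_desc = {
    "dtype": "float16",                   # GetDeviceInputType
    "name": "x0",                         # op_input_name + str(input_i)
    "ori_shape": [2, 8, 64],              # infer shape of the feeding output (or [1] if empty)
    "ori_format": "NCHW",                 # always kOpFormat_NCHW
    "shape": [2, 8, 64],                  # GetDeviceInputShape (differs for fractal formats)
    "format": "DefaultFormat",            # GetDeviceInputFormat
    "valid": True,
    "param_type": "required",             # input_ptr->param_type()
    "range": [[2, 2], [8, 8], [64, 64]],  # TbeDynamicShapeUtil::GetInputDynamicRange
}

# Stub emitted by GenNoneInputDescJson for an omitted optional input.
none_input_desc = {
    "name": "seq_length5",                # input name + str(input_i)
    "valid": False,
}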
@@ -111,6 +111,9 @@ class TbeKernelJsonCreator {
  void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
                     const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
                     std::vector<nlohmann::json> *output_list);
  void GenValidInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
                             const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
                             std::vector<nlohmann::json> *input_list);
  std::vector<size_t> GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const;
  std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const;
  std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
@@ -63,6 +63,7 @@
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
#include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h"
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
@@ -110,6 +111,7 @@
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
#include "utils/ms_context.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
@@ -222,6 +224,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
  data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
  data_layout_pm->AddPass(std::make_shared<RemoveReshapePair>());
  data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>());
  data_layout_pm->AddPass(std::make_shared<InsertTransposeForDynamicGRUV2>());
  data_layout_pm->AddPass(std::make_shared<OptimizeDependence>());
  data_layout_pm->AddPass(std::make_shared<TransDataSplit>());
  data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
@@ -278,6 +281,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
  ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
  AddAscendIRFusionRulesPass(ir_fusion_pm.get());
  AddAscendIRFusionPass(ir_fusion_pm.get());
@@ -0,0 +1,82 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
#include <vector>
#include <memory>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "base/core_ops.h"

namespace mindspore {
namespace opt {
const BaseRef InsertPlaceholderForDynamicGRUV2::DefinePattern() const {
  std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited);
  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
  return VectorRef({V, Xs});
}

const AnfNodePtr InsertPlaceholderForDynamicGRUV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                           const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto op_name = AnfAlgo::GetCNodeName(cnode);
  if (op_name != kDynamicGRUV2OpName) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  size_t input_num = AnfAlgo::GetInputTensorNum(node);
  if (input_num == 0) {
    return nullptr;
  }
  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
  size_t real_input_index = 0;
  for (size_t in_idx = 0; in_idx < input_num + none_index.size(); in_idx++) {
    auto item = find(none_index.begin(), none_index.end(), in_idx);
    if (item != none_index.end()) {
      auto value = std::make_shared<None>();
      auto value_node = NewValueNode(value);
      value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
      auto new_node = kernel_graph->NewValueNode(value_node);
      kernel_graph->AddValueNodeToGraph(new_node);
      new_inputs.push_back(new_node);
    } else {
      auto input_node = AnfAlgo::GetInputNode(cnode, real_input_index);
      new_inputs.push_back(input_node);
      real_input_index++;
    }
  }
  CNodePtr new_node = nullptr;
  if (kernel_graph == nullptr) {
    new_node = std::make_shared<CNode>(*cnode);
  } else {
    new_node = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_inputs(new_inputs);
  return new_node;
}
}  // namespace opt
}  // namespace mindspore
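The new pass rebuilds the DynamicGRUV2 input list so that every position recorded in the placeholder_index attribute is occupied by a None value node, keeping the real inputs at their original operator positions. A minimal Python sketch of that interleaving, assuming the operator's seven-input order (x, weight_input, weight_hidden, bias_input, bias_hidden, seq_length, init_h):

# Mirrors the loop in InsertPlaceholderForDynamicGRUV2::Process (sketch only).
def interleave_placeholders(real_inputs, placeholder_index):
    """Return the full input list with None inserted at each placeholder position."""
    full_inputs = []
    real_pos = 0
    for idx in range(len(real_inputs) + len(placeholder_index)):
        if idx in placeholder_index:
            full_inputs.append(None)  # becomes a None ValueNode in the kernel graph
        else:
            full_inputs.append(real_inputs[real_pos])
            real_pos += 1
    return full_inputs

# Only x, weight_input, weight_hidden and init_h were passed, so infer_shape
# recorded placeholder_index = [3, 4, 5] (bias_input, bias_hidden, seq_length).
print(interleave_placeholders(["x", "wi", "wh", "h0"], [3, 4, 5]))
# ['x', 'wi', 'wh', None, None, None, 'h0']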
@@ -0,0 +1,37 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_GRU_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_GRU_H_

#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"

namespace mindspore {
namespace opt {
class InsertPlaceholderForDynamicGRUV2 : public PatternProcessPass {
 public:
  explicit InsertPlaceholderForDynamicGRUV2(bool multigraph = true)
      : PatternProcessPass("add_placeholder_for_dynamic_gru", multigraph) {}
  ~InsertPlaceholderForDynamicGRUV2() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_PLACEHOLDER_FOR_DYNAMIC_GRU_H_
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_DYANMIC_GRU_V2_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_DYANMIC_GRU_V2_H_

#include <string>
#include <utility>
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h"

namespace mindspore {
namespace opt {
class InsertTransposeForDynamicGRUV2 : public PatternProcessPass {
 public:
  explicit InsertTransposeForDynamicGRUV2(bool multigraph = true)
      : PatternProcessPass("insert_transpose_for_dynamic_gru_v2_op", multigraph) {}
  ~InsertTransposeForDynamicGRUV2() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_DYANMIC_GRU_V2_H_
@@ -0,0 +1,94 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h"
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace opt {
const BaseRef InsertTransposeForDynamicGRUV2::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  VarPtr X1 = std::make_shared<Var>();
  VarPtr Xs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(X);
  MS_EXCEPTION_IF_NULL(X1);
  MS_EXCEPTION_IF_NULL(Xs);
  return VectorRef(
    {prim::kPrimDynamicGRUV2, X1, VectorRef({prim::KPrimTransData, VectorRef({prim::kPrimReshape, X})}), Xs});
}

CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  for (size_t index = 0; index < cnode->inputs().size(); index++) {
    if (index == 1 || index == 2) {
      AnfNodePtr new_node = nullptr;
      AnfNodePtr new_transdata_node = nullptr;
      AnfNodePtr new_transpose_node = nullptr;
      AnfNodePtr transdata_node = AnfAlgo::GetInputNode(cnode, index);
      AnfNodePtr reshape_node = AnfAlgo::GetInputNode(transdata_node->cast<CNodePtr>(), 0);
      auto input_format = AnfAlgo::GetInputFormat(transdata_node, 0);
      auto output_format = AnfAlgo::GetOutputFormat(transdata_node, 0);
      auto padding_axis = AnfAlgo::GetOutputReshapeType(transdata_node, 0);
      KernelSelectPtr kernel_select = std::make_shared<KernelSelect>();
      // trans default to hwcn
      new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(transdata_node->cast<CNodePtr>(), 0),
                                          kernel_select, false, prim::kPrimTranspose->name());
      AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int64_t>{2, 3, 1, 0}), new_transpose_node);
      AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), new_transpose_node);
      RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
      // trans hwcn to output_format
      new_transdata_node =
        NewTransOpNode(func_graph, new_transpose_node, kernel_select, false, prim::KPrimTransData->name());
      RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node, padding_axis);
      new_transdata_node->set_abstract(transdata_node->abstract());
      new_node = new_transdata_node;
      FuncGraphManagerPtr manager = func_graph->manager();
      MS_EXCEPTION_IF_NULL(manager);
      manager->AddFuncGraph(func_graph);
      if (!manager->Replace(transdata_node, new_node)) {
        MS_LOG(EXCEPTION) << "For DynamicGRUV2, manager replace node failed";
      }
    }
  }
  return cnode;
}

const AnfNodePtr InsertTransposeForDynamicGRUV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                         const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto op_name = AnfAlgo::GetCNodeName(cnode);
  CNodePtr new_node = nullptr;
  if (op_name == kDynamicGRUV2OpName) {
    new_node = Insert(func_graph, cnode);
  }
  return new_node;
}
}  // namespace opt
}  // namespace mindspore
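In effect, this pass rewrites the weight path Reshape -> TransData -> DynamicGRUV2 into Reshape -> Transpose -> TransData, where the Transpose uses perm {2, 3, 1, 0} to bring the default (NCHW-ordered) reshape output into HWCN before the new TransData converts it to the device format recorded on the original node. A short numpy sketch of what that permutation does to a 4-D weight; the shapes are illustrative assumptions, not values taken from the pass:

# perm = (2, 3, 1, 0) reorders (N, C, H, W) axes into (H, W, C, N), i.e. HWCN.
import numpy as np

w_nchw = np.zeros((48, 16, 1, 1), dtype=np.float16)  # illustrative weight layout
w_hwcn = np.transpose(w_nchw, (2, 3, 1, 0))
print(w_hwcn.shape)  # (1, 1, 16, 48)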
@@ -59,7 +59,7 @@ bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildIn
  auto name = AnfAlgo::GetCNodeName(cnode);
  for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
    TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
    if (name == kDynamicRNNOpName && input_origin_type == kMetaTypeNone) {
    if ((name == kDynamicRNNOpName || name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) {
      continue;
    }
    if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
@@ -133,6 +133,13 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
    if (op_name == kDynamicRNNOpName && i == 3) {
      continue;
    }
    if (op_name == kDynamicGRUV2OpName) {
      auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node_ptr, "placeholder_index");
      auto item = find(none_index.begin(), none_index.end(), i);
      if (item != none_index.end()) {
        continue;
      }
    }
    auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i);
    auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index);
    AddressPtr input = std::make_shared<Address>();
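At task-generation time the placeholder inputs have no device tensor behind them, so their indices are dropped when kernel launch addresses are collected, mirroring how DynamicRNN skips its fixed seq_length slot (index 3). A hedged sketch of that filter:

# Indices listed in the node's placeholder_index attribute correspond to None
# placeholders with no device memory, so they contribute no launch argument.
def gather_launch_indices(input_num, placeholder_index):
    return [i for i in range(input_num) if i not in placeholder_index]

print(gather_launch_indices(7, [3, 4, 5]))  # [0, 1, 2, 6]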
@@ -227,8 +227,8 @@ constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
constexpr auto kDynamicRNNOpName = "DynamicRNN";
constexpr auto kLSTMInputGradOpName = "LSTMInputGrad";
constexpr auto kDynamicGRUOpName = "DynamicGRU";
constexpr auto kGRUV2HiddenGrad = "GRUV2HiddenGrad";
constexpr auto kDynamicGRUV2OpName = "DynamicGRUV2";
constexpr auto kGRUV2HiddenGradOpName = "GRUV2HiddenGrad";
constexpr auto kFusedSparseFtrlName = "FusedSparseFtrl";
constexpr auto kFusedSparseProximalAdagradName = "FusedSparseProximalAdagrad";
constexpr auto kFusedSparseLazyAdamName = "FusedSparseLazyAdam";
@@ -239,6 +239,7 @@ constexpr auto kLARSUpdateName = "LARSUpdate";
constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
constexpr auto kMatMulV2OpName = "MatMulV2";
constexpr auto kBroadcastToOpName = "BroadcastTo";

// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
@@ -34,8 +34,8 @@ dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \
    .attr("is_training", "optional", "bool", "all", "true") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "weight_input", False, "required", "all") \
    .input(2, "weight_hidden", False, "required", "all") \
    .input(1, "weight_input", False, "required", "all", reshape_type="CN") \
    .input(2, "weight_hidden", False, "required", "all", reshape_type="CN") \
    .input(3, "bias_input", False, "optional", "all") \
    .input(4, "bias_hidden", False, "optional", "all") \
    .input(5, "seq_length", False, "optional", "all") \
@@ -22,7 +22,7 @@ gru_v2_hidden_grad_op_info = TBERegOp("GRUV2HiddenGrad") \
    .binfile_name("gru_v2_hidden_grad.so") \
    .compute_cost(10) \
    .kernel_name("gru_v2_hidden_grad") \
    .attr("gate_order", "optional", "str", "all", "zrh") \
    .attr("gate_order", "optional", "str", "all", "rzh") \
    .partial_flag(True) \
    .input(0, "weight_input", False, "required", "all") \
    .input(1, "init_h", False, "required", "all") \
@@ -1210,7 +1210,7 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
                 num_proj=0,
                 time_major=True,
                 bias_type="double_bias",
                 gate_order="zrh",
                 gate_order="rzh",
                 reset_after=True):
        self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
@@ -1266,12 +1266,13 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype,
                    dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype):
        valid_types = (mstype.float16, mstype.float32)
        args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype,
                "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype,
                "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
        args = {"y_dtype": y_dtype, "h_dtype": h_dtype, "dy_dtype": dy_dtype,
                "dh_dtype": dh_dtype, "update_dtype": update_dtype, "reset_dtype": reset_dtype,
                "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("init_h_dtype", init_h_dtype, valid_types, self.name)
        validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
        if seq_dtype is not None:
            validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)
@@ -549,41 +549,49 @@ class DynamicGRUV2(PrimitiveWithInfer):
        self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
    def infer_shape(self, x_shape, winput_shape, whidden_shape,
                    binput_shape=None, bhidden_shape=None, seq_shape=None, h_shape=None):
        validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
        validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
        validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
        num_step, batch_size, input_size = x_shape
        hidden_size = winput_shape[-1] // 3
        if winput_shape[-1] % 3 != 0:
            raise ValueError(f"For {self.name}, weight_input_shape[-1] should be a multiple of 3.")
        self.placeholder_index = [3, 4, 5, 6]
        if binput_shape is not None:
            validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
            validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
            self.placeholder_index.remove(3)
        if bhidden_shape is not None:
            validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name)
            validator.check("bias_hidden_shape", bhidden_shape,
                            "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
            self.placeholder_index.remove(4)
        if h_shape is not None:
            validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
            validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
            validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
            self.placeholder_index.remove(6)
        if seq_shape is not None:
            raise ValueError(f"For {self.name}, seq_shape should be None.")
        num_step, batch_size, input_size = x_shape
        hidden_size = winput_shape[-1] // 3
        if winput_shape[-1] % 3 != 0:
            raise ValueError(f"For {self.name}, weight_input_shape[-1] should be a multiple of 3.")
| validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]", | |||
| whidden_shape[-1], Rel.EQ, self.name) | |||
| validator.check("bias_input_shape", binput_shape, "bias_hidden_shape", bhidden_shape, Rel.EQ, self.name) | |||
| validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name) | |||
| validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name) | |||
| if h_shape is not None: | |||
| validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name) | |||
| validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name) | |||
| if self.num_proj > 0: | |||
| y_shape = (num_step, batch_size, min(hidden_size, self.num_proj)) | |||
| else: | |||
| y_shape = (num_step, batch_size, hidden_size) | |||
| outh_shape = (num_step, batch_size, hidden_size) | |||
| self.add_prim_attr("placeholder_index", self.placeholder_index) | |||
| return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape | |||
| def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): | |||
| def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, | |||
| binput_dtype=None, bhidden_dtype=None, seq_dtype=None, h_dtype=None): | |||
| validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class DynamicGRUV2(nn.Cell): | |||
| def __init__(self): | |||
| super(DynamicGRUV2, self).__init__() | |||
| self.dynamic_gru = P.DynamicGRUV2() | |||
| def construct(self, x, weight_i, weight_h, bias_i, bias_h, init_h): | |||
| return self.dynamic_gru(x, weight_i, weight_h, bias_i, bias_h, None, init_h) | |||
| def test_dynamic_gru_v2(): | |||
| x = Tensor(np.random.rand(2, 8, 64).astype(np.float16)) | |||
| weight_i = Tensor(np.random.rand(64, 48).astype(np.float16)) | |||
| weight_h = Tensor(np.random.rand(16, 48).astype(np.float16)) | |||
| bias_i = Tensor(np.random.rand(48).astype(np.float16)) | |||
| bias_h = Tensor(np.random.rand(48).astype(np.float16)) | |||
| init_h = Tensor(np.random.rand(8, 16).astype(np.float16)) | |||
| gru_net = DynamicGRUV2() | |||
| output = gru_net(x, weight_i, weight_h, bias_i, bias_h, init_h) | |||
| print(output) | |||
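Assuming the six outputs follow the shapes returned by infer_shape (y plus five hidden-state tensors, all (num_step, batch_size, hidden_size) = (2, 8, 16) here), a hedged check that could be appended to the smoke test:

# Hypothetical helper, not part of the original test file.
def check_output(output):
    assert len(output) == 6            # y, out_h, update, reset, new, hidden_new
    for out in output:
        assert out.shape == (2, 8, 16)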
@@ -2532,11 +2532,7 @@ test_case_other_ops = [
                        Tensor(np.random.rand(48).astype(np.float16)),
                        Tensor(np.random.rand(48).astype(np.float16)),
                        Tensor(np.random.rand(8, 16).astype(np.float16))],
        'desc_bprop': [Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
                       Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
                       Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
                       Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
                       Tensor(np.random.rand(2, 8, 16).astype(np.float16))]}),
        'skip': ['backward']}),
]

test_case_quant_ops = [