| @@ -58,6 +58,7 @@ | |||
| #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_trans_op.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h" | |||
| #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" | |||
| #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" | |||
| #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" | |||
| @@ -284,6 +285,9 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| } | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<AddInputToOutput>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertTranspose>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| optimizer->AddPassManager(ir_fusion_pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| @@ -142,6 +142,15 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An | |||
| return node; | |||
| } | |||
| void ReFreshInferShape(const AnfNodePtr &node, const std::string &op_name) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(node) == prim::kPrimReshape->name()) { | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); | |||
| auto type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); | |||
| AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get()); | |||
| } | |||
| } | |||
| AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const KernelSelectPtr &kernel_select) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -149,6 +158,10 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| size_t out_num = AnfAlgo::GetOutputTensorNum(node); | |||
| std::string op_name; | |||
| if (node->isa<CNode>()) { | |||
| op_name = AnfAlgo::GetCNodeName(node); | |||
| } | |||
| for (size_t output_idx = 0; output_idx < out_num; ++output_idx) { | |||
| std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); | |||
| if (output_format == kOpFormat_NC1KHKWHWC0) { | |||
| @@ -159,6 +172,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const | |||
| std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); | |||
| if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) { | |||
| auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false); | |||
| ReFreshInferShape(trans_op, op_name); | |||
| if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) { | |||
| kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0); | |||
| } | |||
| @@ -0,0 +1,100 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "utils/utils.h" | |||
| #include "backend/optimizer/ascend/ascend_helper.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/kernel_compiler/oplib/oplib.h" | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef InsertTranspose::DefinePattern() const { | |||
| std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited); | |||
| std::shared_ptr<Var> Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({V, Xs}); | |||
| } | |||
| CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::string &op_name) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| CNodePtr new_node = nullptr; | |||
| std::vector<AnfNodePtr> transpose_inputs; | |||
| auto prim = std::make_shared<Primitive>(prim::kPrimTranspose->name()); | |||
| transpose_inputs.push_back(NewValueNode(prim)); | |||
| if (op_name == kBasicLSTMCellInputGradOpName) { | |||
| auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 1); | |||
| auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1); | |||
| auto dst_shape = {origin_shape[1], origin_shape[0]}; | |||
| transpose_inputs.push_back(AnfAlgo::GetInputNode(cnode, 1)); | |||
| CNodePtr transpose = func_graph->NewCNode(transpose_inputs); | |||
| MS_EXCEPTION_IF_NULL(transpose); | |||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {dst_shape}, transpose.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{1, 0}), transpose); | |||
| AnfAlgo::SetNodeInput(cnode, transpose, 1); | |||
| if (kernel_graph == nullptr) { | |||
| new_node = std::make_shared<CNode>(*cnode); | |||
| } else { | |||
| new_node = kernel_graph->NewCNode(cnode); | |||
| } | |||
| } else if (op_name == kBasicLSTMCellWeightGradOpName) { | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| size_t out_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| for (size_t output_idx = 0; output_idx < out_num; output_idx++) { | |||
| auto tuple_getitem = CreatTupleGetItemNode(func_graph, cnode, output_idx); | |||
| auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); | |||
| if (origin_shape.size() > 1 && output_idx == 0) { | |||
| auto dtype = AnfAlgo::GetOutputInferDataType(cnode, output_idx); | |||
| auto dst_shape = {origin_shape[0], origin_shape[1]}; | |||
| transpose_inputs.push_back(tuple_getitem); | |||
| CNodePtr transpose = func_graph->NewCNode(transpose_inputs); | |||
| MS_EXCEPTION_IF_NULL(transpose); | |||
| AnfAlgo::SetOutputInferTypeAndShape({dtype}, {dst_shape}, transpose.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{1, 0}), transpose); | |||
| make_tuple_inputs.push_back(transpose); | |||
| } else { | |||
| make_tuple_inputs.push_back(tuple_getitem); | |||
| } | |||
| } | |||
| new_node = func_graph->NewCNode(make_tuple_inputs); | |||
| } | |||
| return new_node; | |||
| } | |||
| const AnfNodePtr InsertTranspose::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||
| CNodePtr new_node = nullptr; | |||
| if (op_name == kBasicLSTMCellInputGradOpName || op_name == kBasicLSTMCellWeightGradOpName) { | |||
| new_node = Insert(func_graph, cnode, op_name); | |||
| } | |||
| return new_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H | |||
| #include <string> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/optimizer/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class InsertTranspose : public PatternProcessPass { | |||
| public: | |||
| explicit InsertTranspose(bool multigraph = true) | |||
| : PatternProcessPass("insert_transpose_for_basiclstm_op", multigraph) {} | |||
| ~InsertTranspose() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H | |||
| @@ -24,6 +24,7 @@ namespace opt { | |||
| const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, | |||
| {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, | |||
| {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, | |||
| {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM}, | |||
| {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; | |||
| bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { | |||
| @@ -64,12 +65,13 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||
| AnfNodePtr new_transdata_node = nullptr; | |||
| AnfNodePtr new_transpose_node = nullptr; | |||
| AnfNodePtr new_replace_node = nullptr; | |||
| auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); | |||
| // if output_format=default transdata need split transdata->transpose else transpose->transdata | |||
| if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { | |||
| // trans input_format to hwcn | |||
| new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_, | |||
| false, prim::KPrimTransData->name()); | |||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node); | |||
| RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node, padding_axis); | |||
| // trans hwcn to default_format | |||
| new_transpose_node = | |||
| NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); | |||
| @@ -86,7 +88,7 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||
| // trans hwcn to output_format | |||
| new_transdata_node = | |||
| NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); | |||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node); | |||
| RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node, padding_axis); | |||
| new_transdata_node->set_abstract(node->abstract()); | |||
| new_replace_node = new_transdata_node; | |||
| } | |||
| @@ -56,7 +56,7 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, | |||
| template <typename T> | |||
| T DivCeil(T n1, T n2) { | |||
| if (n2 != 0) { | |||
| return (n1 - 1) / n2 + 1; | |||
| return (n1 + n2 - 1) / n2; | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -444,6 +444,17 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } else if (format == kOpFormat_FRACTAL_ZN_LSTM) { | |||
| const size_t c0 = 4; | |||
| const size_t h = shape.at(kN) / c0; | |||
| const size_t i = shape.at(kC) - h; | |||
| const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize); | |||
| const size_t second = c0 * DivCeil(h, kCubeSize); | |||
| device_shape.push_back(first); | |||
| device_shape.push_back(second); | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| if (shape.size() != kNchwDims) { | |||
| MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; | |||
| @@ -196,6 +196,9 @@ constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; | |||
| constexpr auto kTensorAddOpName = "TensorAdd"; | |||
| constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum"; | |||
| constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum"; | |||
| constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad"; | |||
| constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad"; | |||
| constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell"; | |||
| // attr key name | |||
| constexpr auto kAttrInputNames = "input_names"; | |||
| @@ -324,10 +327,11 @@ constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; | |||
| constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; | |||
| constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; | |||
| constexpr auto kOpFormat_NDHWC = "NDHWC"; | |||
| constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM"; | |||
| const std::set<std::string> kOpFormatList = { | |||
| kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, | |||
| kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, | |||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC}; | |||
| kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, | |||
| kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, | |||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM}; | |||
| const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | |||
| const std::set<std::string> kOptOperatorSet = { | |||
| kMomentumOpName, | |||
| @@ -353,9 +357,9 @@ const std::set<std::string> kOptOperatorSet = { | |||
| kPullOpName, | |||
| }; | |||
| const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, | |||
| kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, | |||
| kOpFormat_FRACTAL_Z_C04}; | |||
| const std::set<std::string> kHWSpecialFormatSet = { | |||
| kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, | |||
| kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM}; | |||
| const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; | |||
| @@ -30,7 +30,7 @@ basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "h", False, "required", "all") \ | |||
| .input(2, "c", False, "required", "all") \ | |||
| .input(3, "w", False, "required", "all") \ | |||
| .input(3, "w", False, "required", "all", reshape_type="CN") \ | |||
| .input(4, "b", False, "required", "all") \ | |||
| .input(5, "mask", False, "optional", "all") \ | |||
| .output(0, "ct", False, "required", "all") \ | |||
| @@ -40,11 +40,11 @@ basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \ | |||
| .output(4, "ft", False, "optional", "all") \ | |||
| .output(5, "ot", False, "optional", "all") \ | |||
| .output(6, "tanhct", False, "optional", "all") \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZ, | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZNLSTM, | |||
| DataType.F32_Default, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ, | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZNLSTM, | |||
| DataType.F16_Default, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ) \ | |||
| @@ -25,7 +25,7 @@ basic_lstm_cell_input_grad_op_info = TBERegOp("BasicLSTMCellInputGrad") \ | |||
| .attr("keep_prob", "optional", "float", "all") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "dgate", False, "required", "all") \ | |||
| .input(1, "w", False, "required", "all") \ | |||
| .input(1, "w", False, "required", "all", reshape_type="NC") \ | |||
| .input(2, "dropout_mask", False, "optional", "all") \ | |||
| .output(0, "dxt", False, "required", "all") \ | |||
| .output(1, "dht", False, "required", "all") \ | |||
| @@ -26,7 +26,7 @@ basic_lstm_cell_weight_grad_op_info = TBERegOp("BasicLSTMCellWeightGrad") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "h", False, "required", "all") \ | |||
| .input(2, "dgate", False, "required", "all") \ | |||
| .output(0, "dw", False, "required", "all") \ | |||
| .output(0, "dw", False, "required", "all", reshape_type="CN") \ | |||
| .output(1, "db", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ, | |||
| DataType.F32_Default) \ | |||
| @@ -129,6 +129,10 @@ trans_data_op_info = TBERegOp("TransData") \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_HWCN) \ | |||
| .dtype_format(DataType.F16_HWCN, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F32_HWCN, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_HWCN, DataType.F16_FracZNLSTM) \ | |||
| .dtype_format(DataType.F32_HWCN, DataType.F32_FracZNLSTM) \ | |||
| .dtype_format(DataType.F16_FracZNLSTM, DataType.F16_HWCN) \ | |||
| .dtype_format(DataType.F32_FracZNLSTM, DataType.F32_HWCN) \ | |||
| .get_op_info() | |||
| @@ -619,6 +619,7 @@ class DataType: | |||
| F16_NHWC = ("float16", "NHWC") | |||
| F16_HWCN = ("float16", "HWCN") | |||
| F16_NDHWC = ("float16", "NDHWC") | |||
| F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM") | |||
| F32_None = ("float32", "") | |||
| F32_Default = ("float32", "DefaultFormat") | |||
| @@ -630,6 +631,7 @@ class DataType: | |||
| F32_NHWC = ("float32", "NHWC") | |||
| F32_HWCN = ("float32", "HWCN") | |||
| F32_NDHWC = ("float32", "NDHWC") | |||
| F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM") | |||
| F64_None = ("float64", "") | |||
| F64_Default = ("float64", "DefaultFormat") | |||