From 863c756ad1d71d4d99a54d56d505df4c3e569011 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 28 May 2020 23:34:04 -0400 Subject: [PATCH] add bnll, bninference and bntrainingupdatev3 for vm --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 1 + mindspore/ops/_op_impl/tbe/__init__.py | 1 + .../ops/_op_impl/tbe/bn_training_update_v3.py | 51 +++++++++++++++++++ 3 files changed, 53 insertions(+) create mode 100644 mindspore/ops/_op_impl/tbe/bn_training_update_v3.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index a24166c5b5..2af70bd44b 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -55,6 +55,7 @@ static std::map tbe_func_adapter_map = { {"b_n_training_reduce", "bn_training_reduce"}, {"b_n_training_update", "bn_training_update"}, {"b_n_training_update_v2", "bn_training_update_v2"}, + {"b_n_training_update_v3", "bn_training_update_v3"}, {"b_n_training_reduce_grad", "bn_training_reduce_grad"}, {"b_n_training_update_grad", "bn_training_update_grad"}, {"b_n_infer", "bn_infer"}, diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 7a404cde98..7f23d618d2 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -194,6 +194,7 @@ from .sgd import _sgd_tbe from .lars_update import _lars_update_tbe from .arg_min import _arg_min_tbe from .bn_training_update_v2 import _bn_training_update_v2_tbe +from .bn_training_update_v3 import _bn_training_update_v3_tbe from .square_sum_all import _square_sum_all_tbe from .pack import _pack_tbe from .unpack import _unpack_tbe diff --git a/mindspore/ops/_op_impl/tbe/bn_training_update_v3.py b/mindspore/ops/_op_impl/tbe/bn_training_update_v3.py new file mode 100644 index 0000000000..6d69c6e4be --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/bn_training_update_v3.py @@ -0,0 +1,51 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BNTrainingUpdateV3 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +bn_training_update_v3_op_info = TBERegOp("BNTrainingUpdateV3") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("bn_training_update_v3.so") \ + .compute_cost(10) \ + .kernel_name("bn_training_update_v3") \ + .partial_flag(True) \ + .attr("epsilon", "required", "float", "all") \ + .input(0, "x", False, "required", "all", reshape_type="NC") \ + .input(1, "sum", False, "required", "all") \ + .input(2, "square_sum", False, "required", "all") \ + .input(3, "scale", False, "required", "all") \ + .input(4, "offset", False, "required", "all") \ + .output(0, "y", False, "required", "all", reshape_type="NC") \ + .output(1, "batch_mean", False, "required", "all") \ + .output(2, "batch_variance", False, "required", "all") \ + .output(3, "reserve_1", False, "required", "all") \ + .output(4, "reserve_2", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(bn_training_update_v3_op_info) +def _bn_training_update_v3_tbe(): + """BNTrainingUpdateV3 TBE register""" + return