Merge pull request !1863 from zhaozhenlong/op/broadcast-to-d-vm
@@ -105,7 +105,8 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"unsorted_segment_min", "unsorted_segment_min_d"},
   {"reduce_prod", "reduce_prod_d"},
   {"a_cos", "acos"},
-  {"a_cos_grad", "acos_grad"}};
+  {"a_cos_grad", "acos_grad"},
+  {"broadcast_to", "broadcast_to_d"}};
 
 void TbeAdapter::NormalizeFuncName(std::string *func_name) {
   if (func_name == nullptr) {
@@ -139,7 +140,7 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) {
   *func_name = name_tmp;
   auto iter = tbe_func_adapter_map.find(*func_name);
   if (iter != tbe_func_adapter_map.end()) {
-    MS_LOG(INFO) << "map actual op from me " << func_name << "to tbe op" << iter->second;
+    MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second;
     *func_name = iter->second;
   }
 }
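For context, here is a rough Python sketch of the flow this hunk extends; the regex and the trimmed-down map are illustrative stand-ins for the C++ `NormalizeFuncName` logic, not a translation of it. The ME op name is lowered to snake_case, then remapped when the TBE kernel goes by a different name, such as the static-shape `_d` variants:

```python
import re

# Illustrative stand-in for NormalizeFuncName + tbe_func_adapter_map; the
# map below is a trimmed copy of the entries visible in the hunk above.
TBE_FUNC_ADAPTER_MAP = {
    "a_cos": "acos",
    "a_cos_grad": "acos_grad",
    "broadcast_to": "broadcast_to_d",
}

def normalize_func_name(name: str) -> str:
    # CamelCase -> snake_case, e.g. "BroadcastTo" -> "broadcast_to".
    snake = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
    # Remap when the TBE kernel uses a different (often "_d") name.
    return TBE_FUNC_ADAPTER_MAP.get(snake, snake)

assert normalize_func_name("BroadcastTo") == "broadcast_to_d"
```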
@@ -175,7 +175,7 @@ class FakeQuantWithMinMaxAscend(Cell):
         else:
             quant_fun = P.FakeQuantPerLayer
             ema_fun = P.FakeQuantMinMaxPerLayerUpdate
         self.fake_quant = quant_fun(num_bits=self.num_bits,
                                     ema=self.ema,
                                     ema_decay=self.ema_decay,
@@ -272,7 +272,7 @@ class FakeQuantWithMinMaxGPU(Cell):
                                0, self.out_channels)]).astype(np.float32)
         self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
         self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
         if per_channel:
             quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
         else:
@@ -18,6 +18,7 @@
 from .. import operations as P
 from ..operations import _grad_ops as G
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
+from ..functional import broadcast_gradient_args
 from .. import functional as F
 from .grad_base import bprop_getters
 from ..primitive import constexpr
@@ -580,3 +581,17 @@ def get_bprop_batch_to_space_nd(self):
         dx = batch_to_space_nd_grad(dout)
         return (dx,)
     return bprop
+
+
+@bprop_getters.register(P.BroadcastTo)
+def get_bprop_broadcast_to(self):
+    """Generate bprop for BroadcastTo"""
+    reduce_keep_dim = P.ReduceSum(keep_dims=True)
+    broadcast_shape = self.shape
+    def bprop(x, out, dout):
+        x_shape = shape_op(x)
+        _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
+        reduced_grad = reduce_keep_dim(dout, reduction_axes)
+        dx = reshape(reduced_grad, x_shape)
+        return (dx,)
+    return bprop
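The reasoning behind this bprop, as a self-contained NumPy sketch (`broadcast_gradient_args` and the MindSpore primitives are replaced here with hand-rolled equivalents): broadcasting only copies values, so the gradient with respect to the input is `dout` summed over every broadcast axis and reshaped back to the input shape.

```python
import numpy as np

def broadcast_to_grad(dout, x_shape, broadcast_shape):
    """NumPy stand-in for the bprop: sum dout over the broadcast axes."""
    # Left-pad x_shape with 1s so both shapes have the same rank, mirroring
    # what BroadcastGradientArgs computes internally.
    padded = (1,) * (len(broadcast_shape) - len(x_shape)) + tuple(x_shape)
    reduction_axes = tuple(i for i, (xs, bs) in enumerate(zip(padded, broadcast_shape))
                           if xs == 1 and bs > 1)
    # keepdims=True matches P.ReduceSum(keep_dims=True) above.
    reduced = dout.sum(axis=reduction_axes, keepdims=True)
    return reduced.reshape(x_shape)

# Gradient of broadcasting (3,) -> (2, 3): each input element was copied
# twice, so a ones gradient reduces to 2s.
print(broadcast_to_grad(np.ones((2, 3), dtype=np.float32), (3,), (2, 3)))  # [2. 2. 2.]
```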
@@ -217,9 +217,9 @@ from .bessel_i0e import _bessel_i0e_tbe
 from .bessel_i1e import _bessel_i1e_tbe
 from .batch_to_space_nd import _batch_to_space_nd_tbe
 from .space_to_batch_nd import _space_to_batch_nd_tbe
-from .bitwise_and import bitwise_and_op_info
-from .bitwise_or import bitwise_or_op_info
-from .bitwise_xor import bitwise_xor_op_info
+from .bitwise_and import _bitwise_and_tbe
+from .bitwise_or import _bitwise_or_tbe
+from .bitwise_xor import _bitwise_xor_tbe
 from .reduce_all import _reduce_all_tbe
 from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
 from .unsorted_segment_min import _unsorted_segment_min_tbe
@@ -238,3 +238,4 @@ from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
 from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
 from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
 from .confusion_matrix import _confusion_matrix_tbe
+from .broadcast_to import _broadcast_to_tbe
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""BroadcastTo op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+broadcast_to_op_info = TBERegOp("BroadcastTo") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("broadcast_to_d.so") \
+    .compute_cost(10) \
+    .kernel_name("broadcast_to_d") \
+    .partial_flag(True) \
+    .attr("shape", "required", "listInt", "all") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .get_op_info()
+
+
+@op_info_register(broadcast_to_op_info)
+def _broadcast_to_tbe():
+    """BroadcastTo TBE register"""
+    return
@@ -30,7 +30,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Squeeze, StridedSlice, Tile,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
-                        SpaceToBatchND, BatchToSpaceND)
+                        SpaceToBatchND, BatchToSpaceND, BroadcastTo)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
@@ -289,7 +289,8 @@ __all__ = [
     "Atan",
     "Atanh",
     "BasicLSTMCell",
-    "ConfusionMatrix"
+    "ConfusionMatrix",
+    "BroadcastTo"
 ]
 
 __all__.extend(_quant_ops.__all__)
@@ -2738,3 +2738,40 @@ class BatchToSpaceND(PrimitiveWithInfer):
                              f'block_shape_prod {block_shape_prod}')
         out_shape[0] = out_shape[0] // block_shape_prod
         return out_shape
+
+
+class BroadcastTo(PrimitiveWithInfer):
+    """
+    Broadcasts input tensor to a given shape.
+
+    Args:
+        shape (tuple): The target shape to broadcast.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, with the given `shape` and the same data type as `input_x`.
+
+    Examples:
+        >>> shape = (2, 3)
+        >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
+        >>> broadcast_to = P.BroadcastTo(shape)
+        >>> broadcast_to(input_x)
+        [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
+    """
+
+    @prim_attr_register
+    def __init__(self, shape):
+        """Init BroadcastTo"""
+        validator.check_value_type("shape", shape, (tuple), self.name)
+        for i in shape:
+            validator.check_integer("shape element", i, 0, Rel.GT, self.name)
+        self.shape = shape
+
+    def infer_shape(self, x_shape):
+        return self.shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
+        return x_dtype
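As a sanity check of the docstring example, the semantics match NumPy's `np.broadcast_to`; this is a plain NumPy reference comparison, not MindSpore code:

```python
import numpy as np

# Reference for the docstring example: the (3,) input is repeated along the
# new leading axis to fill the (2, 3) target shape.
x = np.array([1, 2, 3], dtype=np.float32)
print(np.broadcast_to(x, (2, 3)))
# [[1. 2. 3.]
#  [1. 2. 3.]]
```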
@@ -1396,6 +1396,10 @@ test_case_array_ops = [
         'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)),
                         Tensor(np.array([0, 1, 1]).astype(np.int32))],
         'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}),
+    ('BroadcastTo', {
+        'block': P.BroadcastTo((2, 3)),
+        'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
+        'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}),
 ]
 
 test_case_other_ops = [