| @@ -111,6 +111,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"reduce_prod", "reduce_prod_d"}, | |||
| {"a_cos", "acos"}, | |||
| {"a_cos_grad", "acos_grad"}, | |||
| {"histogram_fixed_width", "histogram_fixed_width_d"}, | |||
| {"broadcast_to", "broadcast_to_d"}}; | |||
| void TbeAdapter::NormalizeFuncName(std::string *func_name) { | |||
| @@ -249,3 +249,5 @@ from .fused_mul_add_n_l2loss import _fused_mul_add_n_l2loss_tbe | |||
| from .fused_mul_apply_momentum_extern import _fused_mul_apply_momentum_extern_tbe | |||
| from .lamb_next_right import _lamb_next_right_tbe | |||
| from .sparse_gather_v2 import _sparse_gather_v2_tbe | |||
| from .data_format_dim_map import _data_format_dim_map_tbe | |||
| from .histogram_fixed_width import _histogram_fixed_width_tbe | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """DataFormatDimMap op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| data_format_dim_map_op_info = TBERegOp("DataFormatDimMap") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("data_format_dim_map.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("data_format_dim_map") \ | |||
| .partial_flag(True) \ | |||
| .attr("dst_format", "optional", "str", "all") \ | |||
| .attr("src_format", "optional", "str", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(data_format_dim_map_op_info) | |||
| def _data_format_dim_map_tbe(): | |||
| """DataFormatDimMap TBE register""" | |||
| return | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """HistogramFixedWidth op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| histogram_fixed_width_op_info = TBERegOp("HistogramFixedWidth") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("histogram_fixed_width_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("histogram_fixed_width_d") \ | |||
| .partial_flag(True) \ | |||
| .attr("nbins", "required", "int", "all") \ | |||
| .attr("dtype", "optional", "str", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "range", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(histogram_fixed_width_op_info) | |||
| def _histogram_fixed_width_tbe(): | |||
| """HistogramFixedWidth TBE register""" | |||
| return | |||
| @@ -49,7 +49,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2 | |||
| Minimum, Mul, Neg, NMSWithMask, NotEqual, | |||
| NPUAllocFloatStatus, NPUClearFloatStatus, | |||
| NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | |||
| Reciprocal, CumSum, | |||
| Reciprocal, CumSum, HistogramFixedWidth, | |||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, | |||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh) | |||
| @@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl | |||
| Gelu, Elu, | |||
| GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, | |||
| LogSoftmax, | |||
| MaxPool, | |||
| MaxPool, DataFormatDimMap, | |||
| AvgPool, Conv2DBackpropInput, ConfusionMulGrad, | |||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | |||
| ResizeBilinear, Sigmoid, | |||
| @@ -207,6 +207,7 @@ __all__ = [ | |||
| 'ScatterNd', | |||
| 'ScatterMax', | |||
| 'ResizeNearestNeighbor', | |||
| 'HistogramFixedWidth', | |||
| 'Pad', | |||
| 'MirrorPad', | |||
| 'GatherNd', | |||
| @@ -298,7 +299,8 @@ __all__ = [ | |||
| "BasicLSTMCell", | |||
| "ConfusionMatrix", | |||
| "BroadcastTo", | |||
| "Range" | |||
| "Range", | |||
| "DataFormatDimMap" | |||
| ] | |||
| __all__.extend(_quant_ops.__all__) | |||
| @@ -1043,6 +1043,50 @@ class Expm1(PrimitiveWithInfer): | |||
| return x_type | |||
| class HistogramFixedWidth(PrimitiveWithInfer): | |||
| """ | |||
| Returns a rank 1 histogram counting the number of entries in values that fall into every bin. The bins are equal | |||
| width and determined by the arguments range and nbins. | |||
| Args: | |||
| dtype (string): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32". | |||
| nbins (Tensor): Number of histogram bins, the type is int32. | |||
| Inputs: | |||
| - **x** (Tensor) - Numeric Tensor. Must be one of the following types: int32, float32, float16. | |||
| - **range** (Tensor) - Must have the same type as x. Shape [2] Tensor of same dtype as x. | |||
| x <= range[0] will be mapped to hist[0], x >= range[1] will be mapped to hist[-1]. | |||
| Outputs: | |||
| Tensor, the type is int32. | |||
| Examples: | |||
| >>> x = Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mindspore.float16) | |||
| >>> range = Tensor([0.0, 5.0], mindspore.float16) | |||
| >>> hist = P.HistogramFixedWidth(5) | |||
| >>> hist(x, range) | |||
| [2 1 1 0 2] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, nbins, dtype='int32'): | |||
| self.nbins = validator.check_value_type("nbins", nbins, [int], self.name) | |||
| valid_values = ['int32', 'int64'] | |||
| self.dtype = validator.check_string("dtype", dtype, valid_values, self.name) | |||
| self.init_prim_io_names(inputs=['x', 'range'], outputs=['y']) | |||
| def infer_shape(self, x_shape, range_shape): | |||
| return (self.nbins,) | |||
| def infer_dtype(self, x_dtype, range_dtype): | |||
| validator.check_subclass("x", x_dtype, mstype.tensor, self.name) | |||
| valid_types = (mstype.float16, mstype.float32, mstype.int32) | |||
| validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name) | |||
| y_dtype = mstype.int32 | |||
| return y_dtype | |||
| class Log(PrimitiveWithInfer): | |||
| """ | |||
| Returns the natural logarithm of a tensor element-wise. | |||
| @@ -1613,6 +1613,45 @@ class L2Loss(PrimitiveWithInfer): | |||
| return x_type | |||
| class DataFormatDimMap(PrimitiveWithInfer): | |||
| """ | |||
| Returns the dimension index in the destination data format given the one in the source data format. | |||
| Args: | |||
| src_format (string): An optional value for source data format. Default: 'NHWC'. | |||
| dst_format (string): An optional value for destination data format. Default: 'NCHW'. | |||
| Inputs: | |||
| - **input_x** (Tensor) - A Tensor with each element as a dimension index in source data format. | |||
| Must be in the range [-4, 4). It's type is int32. | |||
| Outputs: | |||
| Tensor, has the same type as the `input_x`. | |||
| Examples: | |||
| >>> x = Tensor([0, 1, 2, 3], mindspore.int32) | |||
| >>> dfdm = P.DataFormatDimMap() | |||
| >>> dfdm(x) | |||
| [0 3 1 2] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, src_format='NHWC', dst_format='NCHW'): | |||
| valid_values = ['NHWC', 'NCHW'] | |||
| self.src_format = validator.check_string("src_format", src_format, valid_values, self.name) | |||
| self.dst_format = validator.check_string("dst_format", dst_format, valid_values, self.name) | |||
| self.init_prim_io_names(inputs=['input_x'], outputs=['output']) | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_subclass("x", x_type, mstype.tensor, self.name) | |||
| valid_types = [mstype.int32] | |||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||
| return x_type | |||
| class SGD(PrimitiveWithInfer): | |||
| """ | |||
| Computes stochastic gradient descent (optionally with momentum). | |||
| @@ -3735,7 +3774,7 @@ class BasicLSTMCell(PrimitiveWithInfer): | |||
| validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name) | |||
| validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name) | |||
| validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name) | |||
| validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4*h_shape[1], Rel.EQ, self.name) | |||
| validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name) | |||
| ct_shape = c_shape | |||
| ht_shape = h_shape | |||
| it_shape = h_shape | |||
| @@ -764,6 +764,11 @@ test_case_math_ops = [ | |||
| 'desc_inputs': [Tensor(np.array([[24, 4, 13, 9], [1, 5, 10, 8]]).astype(np.int16))], | |||
| 'desc_bprop': [], | |||
| 'skip': ['backward']}), | |||
| ('HistogramFixedWidth', { | |||
| 'block': P.HistogramFixedWidth(5), | |||
| 'desc_inputs': [Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mstype.float16), Tensor([0.0, 5.0], mstype.float16)], | |||
| 'desc_bprop': [], | |||
| 'skip': ['backward']}), | |||
| ] | |||
| test_case_nn_ops = [ | |||
| @@ -1203,6 +1208,11 @@ test_case_nn_ops = [ | |||
| Tensor([[0.5, 0.4], [0.6, 0.1]], mstype.float32), Tensor([1, 1], mstype.int32)], | |||
| 'desc_bprop': [Tensor([[0.7, 0.2], [0.1, 0.07]], mstype.float32)], | |||
| 'skip': ['backward']}), | |||
| ('DataFormatDimMap', { | |||
| 'block': P.DataFormatDimMap(), | |||
| 'desc_inputs': [Tensor([0, 1, 2, 3], mstype.int32)], | |||
| 'desc_bprop': [], | |||
| 'skip': ['backward']}), | |||
| ] | |||
| test_case_array_ops = [ | |||