| @@ -137,6 +137,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"histogram_fixed_width", "histogram_fixed_width_d"}, | {"histogram_fixed_width", "histogram_fixed_width_d"}, | ||||
| {"broadcast_to", "broadcast_to_d"}, | {"broadcast_to", "broadcast_to_d"}, | ||||
| {"inplace_update", "inplace_update_d"}, | {"inplace_update", "inplace_update_d"}, | ||||
| {"i_fmr", "ifmr"}, | |||||
| {"matrix_diag", "matrix_diag_d"}, | {"matrix_diag", "matrix_diag_d"}, | ||||
| {"matrix_diag_part", "matrix_diag_part_d"}, | {"matrix_diag_part", "matrix_diag_part_d"}, | ||||
| {"matrix_set_diag", "matrix_set_diag_d"}}; | {"matrix_set_diag", "matrix_set_diag_d"}}; | ||||
| @@ -310,3 +310,4 @@ from .population_count import _population_count_tbe | |||||
| from .parallel_concat import _parallel_concat_tbe | from .parallel_concat import _parallel_concat_tbe | ||||
| from .adam_apply_one_assign import _adam_apply_one_assign_tbe | from .adam_apply_one_assign import _adam_apply_one_assign_tbe | ||||
| from .adam_apply_one_with_decay_assign import _adam_apply_one_with_decay_assign_tbe | from .adam_apply_one_with_decay_assign import _adam_apply_one_with_decay_assign_tbe | ||||
| from .ifmr import _ifmr_tbe | |||||
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """IFMR op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| ifmr_op_info = TBERegOp("IFMR") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("ifmr.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("ifmr") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("min_percentile", "required", "float", "all") \ | |||||
| .attr("max_percentile", "required", "float", "all") \ | |||||
| .attr("search_range", "required", "listFloat", "all") \ | |||||
| .attr("search_step", "required", "float", "all") \ | |||||
| .attr("with_offset", "required", "bool", "all") \ | |||||
| .input(0, "data", False, "required", "all") \ | |||||
| .input(1, "data_min", False, "required", "all") \ | |||||
| .input(2, "data_max", False, "required", "all") \ | |||||
| .input(3, "cumsum", False, "required", "all") \ | |||||
| .output(0, "scale", False, "required", "all") \ | |||||
| .output(1, "offset", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, | |||||
| DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(ifmr_op_info) | |||||
| def _ifmr_tbe(): | |||||
| """IFMR TBE register""" | |||||
| return | |||||
| @@ -53,7 +53,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||||
| NPUAllocFloatStatus, NPUClearFloatStatus, | NPUAllocFloatStatus, NPUClearFloatStatus, | ||||
| NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | ||||
| Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | ||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | |||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, IFMR, | |||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | ||||
| from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | ||||
| @@ -97,6 +97,7 @@ __all__ = [ | |||||
| 'EditDistance', | 'EditDistance', | ||||
| 'CropAndResize', | 'CropAndResize', | ||||
| 'TensorAdd', | 'TensorAdd', | ||||
| 'IFMR', | |||||
| 'Argmax', | 'Argmax', | ||||
| 'Argmin', | 'Argmin', | ||||
| 'ArgMaxWithValue', | 'ArgMaxWithValue', | ||||
| @@ -3514,3 +3514,64 @@ class Eps(PrimitiveWithInfer): | |||||
| 'dtype': input_x['dtype'], | 'dtype': input_x['dtype'], | ||||
| } | } | ||||
| return out | return out | ||||
| class IFMR(PrimitiveWithInfer): | |||||
| """ | |||||
| The TFMR(Input Feature Map Reconstruction). | |||||
| Args: | |||||
| min_percentile (float): Min init percentile. | |||||
| max_percentile (float): Max init percentile. | |||||
| search_range Union[list(float), tuple(float)]: Range of searching. | |||||
| search_step (float): Step size of searching. | |||||
| with_offset (bool): Whether using offset. | |||||
| Inputs: | |||||
| - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type. | |||||
| - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`. | |||||
| With float16 or float32 data type. | |||||
| - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`. | |||||
| With float16 or float32 data type. | |||||
| - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type. | |||||
| Outputs: | |||||
| - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32. | |||||
| - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32. | |||||
| Examples: | |||||
| >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32)) | |||||
| >>> data_min = Tensor([0.1], mstype.float32) | |||||
| >>> data_max = Tensor([0.5], mstype.float32) | |||||
| >>> cumsum = Tensor(np.random.rand(4).astype(np.int32)) | |||||
| >>> ifmr = P.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), | |||||
| search_step=1.0, with_offset=False) | |||||
| >>> output = ifmr(data, data_min, data_max, cumsum) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, min_percentile, max_percentile, search_range, search_step, with_offset): | |||||
| validator.check_value_type("min_percentile", min_percentile, [float], self.name) | |||||
| validator.check_value_type("max_percentile", max_percentile, [float], self.name) | |||||
| validator.check_value_type("search_range", search_range, [list, tuple], self.name) | |||||
| for item in search_range: | |||||
| validator.check_float_positive("item of search_range", item, self.name) | |||||
| validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name) | |||||
| validator.check_value_type("search_step", search_step, [float], self.name) | |||||
| validator.check_value_type("offset_flag", with_offset, [bool], self.name) | |||||
| def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape): | |||||
| validator.check_integer("dims of data_min", len(data_min_shape), 1, Rel.EQ, self.name) | |||||
| validator.check_integer("data_min[0]", data_min_shape[0], 1, Rel.EQ, self.name) | |||||
| validator.check_integer("dims of data_max", len(data_max_shape), 1, Rel.EQ, self.name) | |||||
| validator.check_integer("data_max[0]", data_max_shape[0], 1, Rel.EQ, self.name) | |||||
| validator.check_integer("dims of cumsum", len(cumsum_shape), 1, Rel.EQ, self.name) | |||||
| return (1,), (1,) | |||||
| def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): | |||||
| valid_types = [mstype.float32, mstype.float16] | |||||
| validator.check_tensor_type_same({"input_value": data_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"input_min": data_min_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"input_max": data_max_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"input_bins": cumsum_dtype}, [mstype.int32], self.name) | |||||
| return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) | |||||
| @@ -1275,6 +1275,13 @@ test_case_math_ops = [ | |||||
| 'block': P.Mod(), | 'block': P.Mod(), | ||||
| 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | ||||
| 'desc_bprop': [[2, 3, 4, 5]]}), | 'desc_bprop': [[2, 3, 4, 5]]}), | ||||
| ('IFMR', { | |||||
| 'block': P.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), | |||||
| search_step=1.0, with_offset=False), | |||||
| 'desc_inputs': [[3, 4, 5], Tensor([0.1], mstype.float32), Tensor([0.9], mstype.float32), | |||||
| Tensor(np.random.rand(4).astype(np.int32))], | |||||
| 'desc_bprop': [], | |||||
| 'skip': ['backward']}), | |||||
| ] | ] | ||||
| test_case_nn_ops = [ | test_case_nn_ops = [ | ||||