Merge pull request !8207 from liangchenghui/add_quant_opstags/v1.1.0
| @@ -176,3 +176,27 @@ def get_bprop_fakequant_with_minmax_per_channel_update(self): | |||
| return zeros_like(x), zeros_like(x_min), zeros_like(x_max) | |||
| return bprop | |||
| @bprop_getters.register(Q.ActsULQ) | |||
| def get_bprop_acts_ulq(self): | |||
| """Grad definition for 'ActsULQ' operation""" | |||
| op = Q.ActsULQInputGrad() | |||
| op1 = Q.ActULQClampMinGrad() | |||
| op2 = Q.ActULQClampMaxGrad() | |||
| def bprop(x, clamp_min, clamp_max, out, dout): | |||
| dx = op(dout[0], out[1], out[2]) | |||
| dx1 = op1(dout[0], out[1], out[3]) | |||
| dx2 = op2(dout[0], out[2], out[3]) | |||
| return (dx, dx1, dx2) | |||
| return bprop | |||
| @bprop_getters.register(Q.WtsARQ) | |||
| def get_bprop_wts_arq(self): | |||
| """Grad definition for 'WtsArq' operation""" | |||
| def bprop(w, w_min, w_max, out, dout): | |||
| return (dout, zeros_like(w_min), zeros_like(w_max)) | |||
| return bprop | |||
| @@ -326,6 +326,11 @@ from .parallel_concat import _parallel_concat_tbe | |||
| from .adam_apply_one_assign import _adam_apply_one_assign_tbe | |||
| from .adam_apply_one_with_decay_assign import _adam_apply_one_with_decay_assign_tbe | |||
| from .ifmr import _ifmr_tbe | |||
| from .acts_ulq import _acts_ulq_tbe | |||
| from .acts_ulq_input_grad import _acts_ulq_input_grad_tbe | |||
| from .act_ulq_clamp_min_grad import _act_ulq_clamp_min_grad_tbe | |||
| from .act_ulq_clamp_max_grad import _act_ulq_clamp_max_grad_tbe | |||
| from .wts_arq import _wts_arq_tbe | |||
| from .fake_quant_with_min_max_vars import _fake_quant_with_min_max_vars_tbe | |||
| from .fake_quant_with_min_max_vars_gradient import _fake_quant_with_min_max_vars_gradient_tbe | |||
| from .fake_quant_with_min_max_vars_per_channel import _fake_quant_with_min_max_vars_per_channel_tbe | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ActULQClampMaxGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| act_ulq_clamp_max_grad_op_info = TBERegOp("ActULQClampMaxGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("act_ulq_clamp_max_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("act_ulq_clamp_max_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "input_x", False, "required", "all") \ | |||
| .input(1, "input_y", False, "required", "all") \ | |||
| .input(2, "input_z", False, "required", "all") \ | |||
| .output(0, "output", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(act_ulq_clamp_max_grad_op_info) | |||
| def _act_ulq_clamp_max_grad_tbe(): | |||
| """ActULQClampMaxGrad TBE register""" | |||
| return | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ActULQClampMinGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| act_ulq_clamp_min_grad_op_info = TBERegOp("ActULQClampMinGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("act_ulq_clamp_min_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("act_ulq_clamp_min_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "input_x", False, "required", "all") \ | |||
| .input(1, "input_y", False, "required", "all") \ | |||
| .input(2, "input_z", False, "required", "all") \ | |||
| .output(0, "output", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(act_ulq_clamp_min_grad_op_info) | |||
| def _act_ulq_clamp_min_grad_tbe(): | |||
| """ActULQClampMinGrad TBE register""" | |||
| return | |||
| @@ -0,0 +1,45 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ActsULQ op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| acts_ulq_op_info = TBERegOp("ActsULQ") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("acts_ulq.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("acts_ulq") \ | |||
| .partial_flag(True) \ | |||
| .attr("fixed_min", "optional", "bool", "all") \ | |||
| .attr("num_bits", "optional", "int", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "clamp_min", False, "required", "all") \ | |||
| .input(2, "clamp_max", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .output(1, "clamp_min_mask", False, "required", "all") \ | |||
| .output(2, "clamp_max_mask", False, "required", "all") \ | |||
| .output(3, "x_clamped_loss", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(acts_ulq_op_info) | |||
| def _acts_ulq_tbe(): | |||
| """ActsULQ TBE register""" | |||
| return | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ActsULQInputGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| acts_ulq_input_grad_op_info = TBERegOp("ActsULQInputGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("acts_ulq_input_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("acts_ulq_input_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "y_grad", False, "required", "all") \ | |||
| .input(1, "clamp_min_mask", False, "required", "all") \ | |||
| .input(2, "clamp_max_mask", False, "required", "all") \ | |||
| .output(0, "x_grad", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(acts_ulq_input_grad_op_info) | |||
| def _acts_ulq_input_grad_tbe(): | |||
| """ActsULQInputGrad TBE register""" | |||
| return | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """WtsARQ op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| wts_arq_op_info = TBERegOp("WtsARQ") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("wts_arq.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("wts_arq") \ | |||
| .partial_flag(True) \ | |||
| .attr("num_bits", "optional", "int", "all") \ | |||
| .attr("offset_flag", "optional", "bool", "all") \ | |||
| .input(0, "w", False, "required", "all") \ | |||
| .input(1, "w_min", False, "required", "all") \ | |||
| .input(2, "w_max", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(wts_arq_op_info) | |||
| def _wts_arq_tbe(): | |||
| """WtsARQ TBE register""" | |||
| return | |||
| @@ -1182,3 +1182,205 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): | |||
| def infer_dtype(self, dout_type, x_type): | |||
| validator.check("dout type", dout_type, "x type", x_type) | |||
| return dout_type, dout_type | |||
| class ActsULQ(PrimitiveWithInfer): | |||
| """ | |||
| The ActsULQ(Activation universal learnable quantization). | |||
| Args: | |||
| fixed_min (bool): whether fix clamp min to zero. | |||
| num_bits (int): The bits num used for quantize. | |||
| Inputs: | |||
| - **x** (Tensor) - A Tensor of feature map. With float16 or float32 data type. | |||
| - **clamp_min** (Tensor) - A Tensor of clamp min with the same type as x. | |||
| - **clamp_max** (Tensor) - A Tensor of clamp max with the same type as x. | |||
| Outputs: | |||
| - **y** (Tensor) - A tensor of fake quant of feature map with the same type as `w`. | |||
| - **clamp_min** (Tensor) - A tensor of boolean masks if data in feature map >= clamp_min. | |||
| - **clamp_max** (Tensor) - A tensor of boolean masks if data in feature map <= clamp_max. | |||
| - **x_clamped_loss** (Tensor) - A tensor of clamped loss. | |||
| Examples: | |||
| >>> data_type = np.float32 | |||
| >>> x= np.random.uniform(-10, 10, (32, 120)).astype(data_type) | |||
| >>> clamp_max = 0.7 * np.max(x) | |||
| >>> clamp_min = 0.7 * np.min(x) | |||
| >>> clamp_max = np.array([clamp_max], dtype=data_type) | |||
| >>> clamp_min = np.array([clamp_min], dtype=data_type) | |||
| >>> acts_ulq = Q.ActsULQ(fixed_mini=True, num_bits=8) | |||
| >>> quant_x, clamp_min_mask, clamp_max_mask, x_clamped_loss = acts_ulq(Tensor(x), Tensor( clamp_min), | |||
| Tensor(clamp_max)) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, fixed_min=False, num_bits=8): | |||
| validator.check_value_type("fixed_min", fixed_min, [bool], self.name) | |||
| validator.check_value_type("num_bits", num_bits, [int], self.name) | |||
| validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name) | |||
| def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape): | |||
| """infer shape of primitive""" | |||
| validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name) | |||
| validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name) | |||
| x_shape_len = len(x_shape) | |||
| for i in range(x_shape_len): | |||
| validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name) | |||
| validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name) | |||
| return x_shape, x_shape, x_shape, x_shape | |||
| def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype): | |||
| """infer dtype of primitive""" | |||
| valid_types = [mstype.float32, mstype.float16] | |||
| validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"clamp_min": clamp_min_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"clamp_max": clamp_max_dtype}, valid_types, self.name) | |||
| return x_dtype, mstype.bool_, mstype.bool_, x_dtype | |||
| class ActsULQInputGrad(PrimitiveWithInfer): | |||
| """ | |||
| The ActsULQInputGrad(grad of ActsULQ). | |||
| Inputs: | |||
| - **y_grad** (Tensor) - A Tensor of grad. With float16 or float32 data type. | |||
| Outputs: | |||
| - **x_grad** (Tensor) - A tensor of data grad with the same type as `y_grad`. | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| pass | |||
| def infer_shape(self, y_grad_shape, clamp_min_mask_shape, clamp_max_mask_shape): | |||
| return y_grad_shape | |||
| def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type): | |||
| valid_types = [mstype.float32, mstype.float16] | |||
| validator.check_tensor_type_same({"y_grad": y_grad_type}, valid_types, self.name) | |||
| return y_grad_type | |||
| class ActULQClampMinGrad(PrimitiveWithInfer): | |||
| """ | |||
| The ActULQClampMinGrad(Activation Universal Linear Quantization on Clamp Minimum Gradient) | |||
| Inputs: | |||
| - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type. | |||
| - **clamp_min_mask** - A tensor of mask, only support int8 type. | |||
| - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad". | |||
| Outputs: | |||
| - **clamp_min_grad** - A tensor of clamp minimum gradient, with the same type as "y_grad". | |||
| The length of tensor is 1. | |||
| Examples: | |||
| >>> data_type = np.float32 | |||
| >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type) | |||
| >>> clamp_min_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0) | |||
| >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type) | |||
| >>> act_ulq_clamp_min_grad = Q.ActULQClampMinGrad() | |||
| >>> clamp_min_grad = act_ulq_clamp_min_grad(Tensor(y_grad), Tensor(clamp_min_mask, mindspore.bool_), | |||
| Tensor(x_clamped_loss)) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| pass | |||
| def infer_shape(self, input_x, input_y, input_z): | |||
| input_x_len = len(input_x) | |||
| output_shape = [] | |||
| for _ in range(input_x_len): | |||
| output_shape.append(1) | |||
| return tuple(output_shape) | |||
| def infer_dtype(self, input_x, input_y, input_z): | |||
| return input_x | |||
| class ActULQClampMaxGrad(PrimitiveWithInfer): | |||
| """ | |||
| The ActULQClampMaxGrad(Activation Universal Linear Quantization on Clamp Maximum Gradient) | |||
| Inputs: | |||
| - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type. | |||
| - **clamp_max_mask** - A tensor of mask, only support int8 type. | |||
| - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad". | |||
| Outputs: | |||
| - **clamp_max_grad** - A tensor of clamp maximum gradient, with the same type as "y_grad". | |||
| The length of tensor is 1. | |||
| Examples: | |||
| >>> data_type = np.float32 | |||
| >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type) | |||
| >>> clamp_max_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0) | |||
| >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type) | |||
| >>> act_ulq_clamp_max_grad = Q.ActULQClampMaxGrad() | |||
| >>> clamp_max_grad = act_ulq_clamp_max_grad(Tensor(y_grad), Tensor(clamp_max_mask, mindspore.bool_), | |||
| Tensor(x_clamped_loss)) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| pass | |||
| def infer_shape(self, input_x, input_y, input_z): | |||
| input_x_len = len(input_x) | |||
| output_shape = [] | |||
| for _ in range(input_x_len): | |||
| output_shape.append(1) | |||
| return tuple(output_shape) | |||
| def infer_dtype(self, input_x, input_y, input_z): | |||
| return input_x | |||
| class WtsARQ(PrimitiveWithInfer): | |||
| """ | |||
| The WtsARQ(Weights Adaptive Range Quantization). | |||
| Args: | |||
| axes (list): Specify channels for ARQ algorithm. | |||
| num_bits (int): The bits num used for quantize. | |||
| offset_flag (bool): Whether use offset for quantize. | |||
| Inputs: | |||
| - **w** (Tensor) - A Tensor of weights. With float16 or float32 data type. | |||
| Outputs: | |||
| - **scale** (Tensor) - A tensor of optimal scale, has the same type as `w`. | |||
| - **offset** (Tensor) - A tensor of optimal offset, has the same type as `w`. | |||
| - If axis is [], | |||
| the shape of scale and offset is :math:`(1, )`. | |||
| - If axis is [0], | |||
| the shape of scale and offset is :math:`(w_1, )`. | |||
| - If axis is [1], | |||
| the shape of scale and offset is :math:`(w_2, )`. | |||
| - **y** (Tensor) - A tensor of fakequant weights, has the same type and shape as `w`. | |||
| Examples: | |||
| >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32)) | |||
| >>> wts_arq = Q.WtsARQ(axes=[0], num_bits=8, offset_flag=False) | |||
| >>> scale, offset, y = wts_arq(data) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, num_bits, offset_flag): | |||
| validator.check_value_type("num_bits", num_bits, [int], self.name) | |||
| validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name) | |||
| validator.check_value_type("offset_flag", offset_flag, [bool], self.name) | |||
| def infer_shape(self, w_shape, w_min_shape, w_max_shape): | |||
| validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name) | |||
| validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name) | |||
| return w_shape | |||
| def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype): | |||
| valid_types = [mstype.float32, mstype.float16] | |||
| validator.check_tensor_type_same({"w": w_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"w_min": w_min_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"w_max": w_max_dtype}, valid_types, self.name) | |||
| return w_dtype | |||