Merge pull request !3250 from vlne-v1/quant_export_geir_biasadd_after_depth_wise (tag: v0.6.0-beta)
@@ -395,7 +395,7 @@ void ExecutorPy::GetGeBackendPolicy() const {
 bool IsPhaseExportGeir(const std::string &phase_s) {
   auto phase_to_export = "export.geir";
-  return phase_s.rfind(phase_to_export, 0) != std::string::npos;
+  return phase_s.rfind(phase_to_export) != std::string::npos;
 }
 
 std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
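Note on the hunk above: `std::string::rfind(pattern, 0)` is the idiomatic prefix check, so the old code only recognised phase names that start with "export.geir"; dropping the position argument makes the check succeed when "export.geir" appears anywhere in the phase name. A rough Python analogue, purely illustrative (the phase string is made up):

```python
phase_s = "train.export.geir"  # hypothetical phase name, for illustration only

old_match = phase_s.startswith("export.geir")   # old check: prefix only -> False
new_match = "export.geir" in phase_s            # new check: substring anywhere -> True
print(old_match, new_match)                     # False True
```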
@@ -757,7 +757,7 @@ ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits<int>(),
 OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}};
 
 // Conv2D
-INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
+INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}};
 ATTR_MAP(Conv2D) = {
   {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
@@ -794,7 +794,7 @@ ATTR_MAP(Conv2DBackpropFilterD) = {
 OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
 
 // DepthwiseConv2D
-INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
+INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}};
 ATTR_MAP(DepthwiseConv2D) = {
   {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
@@ -826,7 +826,7 @@ ATTR_MAP(DepthwiseConv2DBackpropFilterD) = {
 OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}};
 
 // MatMulV2
-INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
+INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(bias)}};
 ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits<bool>())},
                       {"transpose_b", ATTR_DESC(transpose_x2, AnyTraits<bool>())}};
 OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}};
@@ -1347,7 +1347,8 @@ OUTPUT_MAP(AscendQuant) = {{0, OUTPUT_DESC(y)}};
 // AscendDequant
 INPUT_MAP(AscendDequant) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(deq_scale)}};
 ATTR_MAP(AscendDequant) = {{"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits<bool>())},
-                           {"relu_flag", ATTR_DESC(relu_flag, AnyTraits<bool>())}};
+                           {"relu_flag", ATTR_DESC(relu_flag, AnyTraits<bool>())},
+                           {"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())}};
 OUTPUT_MAP(AscendDequant) = {{0, OUTPUT_DESC(y)}};
 
 #ifdef ENABLE_GE
@@ -28,8 +28,8 @@ from mindspore._checkparam import check_int_positive, check_bool, twice
 from mindspore._checkparam import Rel
 import mindspore.context as context
-from .normalization import BatchNorm2d
-from .activation import get_activation
+from .normalization import BatchNorm2d, BatchNorm1d
+from .activation import get_activation, ReLU
 from ..cell import Cell
 from . import conv, basic
 from ..._checkparam import ParamValidator as validator
@@ -206,7 +206,7 @@ class DenseBnAct(Cell):
         self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
         if has_bn:
-            self.batchnorm = BatchNorm2d(out_channels)
+            self.batchnorm = BatchNorm1d(out_channels)
         self.activation = get_activation(activation)
 
     def construct(self, x):
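The switch to BatchNorm1d matches the rank of the Dense output: a Dense layer produces a 2-D (batch, out_channels) tensor, which is what BatchNorm1d normalises, whereas BatchNorm2d is meant for 4-D NCHW feature maps. A minimal sketch of that assumption (standard MindSpore API; the shapes are illustrative):

```python
import numpy as np
from mindspore import Tensor, nn

dense_out = Tensor(np.ones((32, 64)).astype(np.float32))  # Dense output: (batch, out_channels)
bn = nn.BatchNorm1d(64)          # normalises each of the 64 features over the batch
out = bn(dense_out)              # keeps the (32, 64) shape
```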
@@ -1156,13 +1156,18 @@ class QuantBlock(Cell):
         self.has_bias = bias is not None
         self.activation = activation
         self.has_act = activation is not None
+        if isinstance(activation, ReLU):
+            self.activation = None
+            self.has_act = False
+            self.dequant.add_prim_attr("relu_flag", True)
         self.bias_add = P.BiasAdd()
 
     def construct(self, x):
         x = self.quant(x)
-        x = self.core_op(x, self.weight)
         if self.has_bias:
-            x = self.bias_add(x, self.bias)
+            x = self.core_op(x, self.weight, self.bias)
+        else:
+            x = self.core_op(x, self.weight)
         if self.has_act:
             x = self.activation(x)
         x = self.dequant(x, self.dequant_scale)
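The net effect of the QuantBlock changes is that the bias is consumed by the core op itself (matching the new `{3, INPUT_DESC(bias)}` entries in the GE adapters above) and a trailing ReLU is folded into the dequant step via `relu_flag`, so the exported GEIR graph no longer needs standalone BiasAdd or ReLU nodes. A framework-free sketch of the resulting data flow, with placeholder callables standing in for the real ops (illustrative only, not MindSpore API):

```python
def quant_block_forward(x, core_op, weight, bias, activation, quant, dequant, dequant_scale):
    x = quant(x)
    if bias is not None:
        x = core_op(x, weight, bias)        # bias rides along as the op's third input
    else:
        x = core_op(x, weight)
    if activation is not None:              # a ReLU would already have been folded into dequant
        x = activation(x)
    return dequant(x, dequant_scale)        # dequant applies the ReLU itself when relu_flag is set
```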
@@ -380,6 +380,7 @@ class Dequant(PrimitiveWithInfer):
     def __init__(self, sqrt_mode=False, relu_flag=False):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
+        self.add_prim_attr("dtype", mstype.float16)
 
     def infer_shape(self, x_shape, deq_scale_shape):
         return x_shape
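This pairs with the new `{"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())}` entry in the AscendDequant ATTR_MAP above: the Python primitive now always carries a `dtype` attribute, which the GE adapter copies onto the exported node so the dequantized output is declared as float16. A tiny, framework-free sketch of that attribute handshake (class and names are hypothetical):

```python
class PrimSketch:
    """Stand-in for a primitive's attribute storage."""
    def __init__(self):
        self.attrs = {}

    def add_prim_attr(self, name, value):
        self.attrs[name] = value

prim = PrimSketch()
prim.add_prim_attr("dtype", "float16")   # mirrors self.add_prim_attr("dtype", mstype.float16)
print(prim.attrs)                        # an exporter would read this when emitting the node
```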
@@ -596,7 +596,7 @@ class MatMul(PrimitiveWithInfer):
             raise ValueError('MatMul input x, y should be the same dimension size and should be '
                              + f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
 
-    def infer_shape(self, x, y):
+    def infer_shape(self, x, y, bias=None):
         self.check_shape_size(x, y)
         cls_name = self.name
         # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
@@ -621,7 +621,7 @@ class MatMul(PrimitiveWithInfer):
         ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
         return ret_dims
 
-    def infer_dtype(self, x, y):
+    def infer_dtype(self, x, y, bias=None):
         args = {"x": x, "y": y}
         validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
         if x.element_type() == mstype.int8:
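Conv2D and DepthwiseConv2dNative below get the same treatment: the infer methods accept an optional bias argument so the primitive can be called with three inputs when the bias is fused, while two-input calls keep working and the bias never changes the inferred output. A simplified, framework-free sketch of the pattern (no transpose handling; names are illustrative):

```python
def infer_shape(x_shape, y_shape, bias_shape=None):
    # Simplified MatMul shape rule: (..., a, b) x (..., b, c) -> (..., a, c);
    # the optional bias only needs to be accepted, it does not alter the result.
    assert x_shape[-1] == y_shape[-2]
    return x_shape[:-1] + y_shape[-1:]

print(infer_shape([4, 8], [8, 16]))          # [4, 16] - two-input call still works
print(infer_shape([4, 8], [8, 16], [16]))    # [4, 16] - bias tolerated as a third input
```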
@@ -842,7 +842,7 @@ class Conv2D(PrimitiveWithInfer):
         self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
         self.add_prim_attr('offset_a', 0)
 
-    def infer_shape(self, x_shape, w_shape):
+    def infer_shape(self, x_shape, w_shape, b_shape=None):
         validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
         validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
         validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
@@ -887,7 +887,7 @@ class Conv2D(PrimitiveWithInfer):
         out_shape = [x_shape[0], out_channel, h_out, w_out]
         return out_shape
 
-    def infer_dtype(self, x_dtype, w_dtype):
+    def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
         args = {'x': x_dtype, 'w': w_dtype}
         valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
         validator.check_tensor_type_same(args, valid_types, self.name)
@@ -968,7 +968,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
         self.group = validator.check_integer("group", group, 0, Rel.GT, self.name)
         self.add_prim_attr('offset_a', 0)
 
-    def infer_shape(self, x_shape, w_shape):
+    def infer_shape(self, x_shape, w_shape, b_shape=None):
         validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
         validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
         validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
@@ -1011,7 +1011,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
         out_shape = [x_shape[0], out_channel, h_out, w_out]
         return out_shape
 
-    def infer_dtype(self, x_dtype, w_dtype):
+    def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
         args = {'x': x_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
         if x_dtype.element_type() == mstype.int8:
@@ -78,7 +78,7 @@ def test_qat_lenet():
 def test_qat_mobile_per_channel_tf():
     network = mobilenetV2(num_classes=1000)
     img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
-    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, True], symmetric=[True, False])
+    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
     # should load the checkpoint. mock here
     for param in network.get_parameters():
         param.init_data()