@@ -487,6 +487,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                         }));
  (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
    .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
    .def(py::pickle(
      [](const MetaTensor &t) {  // __getstate__
        /* Return a tuple that fully encodes the state of the object */
        return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
      },
      [](const py::tuple &t) {  // __setstate__
        if (t.size() != 2) {
          throw std::runtime_error("Invalid state!");
        }
        /* Create a new C++ instance */
        MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
        return tensor;
      }))
    .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
    .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
    .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");
@@ -220,6 +220,8 @@ const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");

// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");

@@ -228,6 +228,8 @@ extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLike;
extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;
extern const PrimitivePtr kPrimFakeQuantPerLayer;
extern const PrimitivePtr kPrimFakeQuantPerChannel;

// Other Miscellaneous
extern const PrimitivePtr kPrimIdentity;
@@ -77,6 +77,8 @@ PYBIND11_MODULE(_c_expression, m) {
         "Get CNode Strategy Dictionary.")
    .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
         "Get Allreduce Fusion Dictionary.")
    .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
         "Fetch the inputs of Conv or Matmul for quant export.")
    .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
         py::arg("broadcast_params") = py::dict(), "Build data graph.")
    .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.")
@@ -281,6 +281,75 @@ ExecutorPy::~ExecutorPy() {
  ConfigManager::GetInstance().ResetConfig();
}

std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
  const std::string &phase_s) {
  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
  auto filter = [](AnfNodePtr node) {
    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
  };
  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
  auto is_quant_cnode = [](AnfNodePtr node) {
    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
  };
  for (auto node : nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr || cnode->size() != 3) {
      continue;
    }
    auto x = cnode->input(1);
    auto weight = cnode->input(2);
    if (!is_quant_cnode(weight)) {
      continue;
    }
    // get the name of the weight fake quant's minq parameter (used as the table key)
    cnode = weight->cast<CNodePtr>();
    auto weight_node = cnode->input(2);
    if (!weight_node->isa<Parameter>()) {
      continue;
    }
    auto weight_name = weight_node->cast<ParameterPtr>()->name();
    // walk up from the first input to find its fake quant node (at most max_depth hops)
    int count = 0;
    int max_depth = 5;
    while (!is_quant_cnode(x)) {
      if (count >= max_depth) {
        break;
      }
      cnode = x->cast<CNodePtr>();
      if (cnode == nullptr || cnode->size() <= 1) {
        break;
      }
      x = cnode->input(1);
      count += 1;
    }
    // get the name of the input fake quant's minq parameter
    if (!is_quant_cnode(x)) {
      continue;
    }
    cnode = x->cast<CNodePtr>();
    if (cnode == nullptr || cnode->size() != 4) {
      continue;
    }
    auto fakequant_min_node = cnode->input(2);
    if (!fakequant_min_node->isa<Parameter>()) {
      continue;
    }
    auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
    auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
    if (!quant_op_value->isa<PrimitivePy>()) {
      continue;
    }
    auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
    fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
  }
  return fake_quant_table;
}

void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
  // save the graph to ExecutorPy
  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
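The table built above is keyed by the name of the weight fake quant's `minq` parameter and maps to the input-side fake quant primitive plus the name of its `minq` parameter. A sketch of its Python-side shape (entry names hypothetical):

```python
# {weight_fake_quant_minq_name: (input_fake_quant_primitive, input_minq_param_name)}
example_table = {
    "conv1.fake_quant_weight.minq": ("<PrimitivePy FakeQuantPerLayer>", "conv1.input_fake_quant.minq"),
}
```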
@@ -97,6 +97,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
  void ReleaseResource(const py::object &phase);
  static void ClearRes();
  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(const std::string &phase_s);

 private:
  ExecutorPy();
  void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);

@@ -39,6 +39,7 @@ namespace mindspore {
enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE };

using IncludeFunc = std::function<IncludeType(const AnfNodePtr &)>;
using FilterFunc = std::function<bool(const AnfNodePtr &)>;
using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;

@@ -58,6 +59,9 @@ std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const Incl
std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
                                                        const FilterFunc &filter);
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
                                 const IncludeFunc &include = AlwaysInclude);
@@ -37,7 +37,8 @@ namespace mindspore {
namespace {
class DeepFirstSearcher : public AnfVisitor {
 public:
  explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {}
  explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
      : include_(include), filter_(filter) {}
  ~DeepFirstSearcher() override = default;

  std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {

@@ -61,8 +62,9 @@ class DeepFirstSearcher : public AnfVisitor {
    if (incl == EXCLUDE) {
      return;
    }
    res_.push_back(node);
    if (filter_ == nullptr || !filter_(node)) {
      res_.push_back(node);
    }
    if (incl == FOLLOW) {
      AnfVisitor::Visit(node);
    }

@@ -71,6 +73,7 @@ class DeepFirstSearcher : public AnfVisitor {
 private:
  size_t seen_{0};
  IncludeFunc include_;
  FilterFunc filter_;
  std::vector<AnfNodePtr> res_{};
};

@@ -160,10 +163,16 @@ class DeepLinkedGraphSearcher : public DeepFirstSearcher {
};
}  // namespace
// `include` decides whether the search expands a node; `filter` returns true for nodes to drop from the results.
std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
  return DeepScopedGraphSearcher(include).Search(root);
}

std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
                                                        const FilterFunc &filter) {
  return DeepFirstSearcher(include, filter).Search(root);
}

std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
  return DeepUsedGraphSearcher(include).Search(root);
}
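A minimal Python sketch of the `include`/`filter` contract implemented by `DeepFirstSearcher` above; the node type and `succ` function are stand-ins for the AnfNode input traversal:

```python
from enum import Enum

class Include(Enum):
    FOLLOW = 0    # collect (unless filtered) and expand the node's inputs
    NOFOLLOW = 1  # collect (unless filtered) but do not expand
    EXCLUDE = 2   # skip the node entirely

def deep_search(root, include, filter_fn=None, succ=lambda n: n.inputs):
    seen, res = set(), []
    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        incl = include(node)
        if incl is Include.EXCLUDE:
            return
        if filter_fn is None or not filter_fn(node):
            res.append(node)
        if incl is Include.FOLLOW:
            for nxt in succ(node):
                visit(nxt)
    visit(root)
    return res
```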
@@ -526,6 +526,11 @@ class _Executor:
        phase = 'export' + '.' + str(net.create_time)
        export_graph(file_name, file_format, phase)
    def fetch_info_for_quant_export(self, exec_id):
        """Fetch the information needed for quant export from the pipeline."""
        if not self._executor.has_compiled(exec_id):
            return None
        return self._executor.fetch_info_for_quant_export(exec_id)

_executor = _Executor()
_pynative_exec = _PynativeExecutor()
@@ -18,8 +18,6 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register

@@ -85,13 +83,12 @@ class _BatchNorm(Cell):
        self.reshape = P.Reshape()
        self.is_ascend = context.get_context("device_target") == "Ascend"
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        self.momentum = 1.0 - momentum
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
            self.momentum = Tensor(1.0 - momentum, mstype.float32)
        else:
            self.is_ge_backend = False
            self.momentum = 1.0 - momentum
        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
            self.bn_train = P.BatchNorm(is_training=True,
                                        epsilon=self.eps)
@@ -729,8 +729,8 @@ class DenseQuant(Cell):
        self.has_bias = check_bool(has_bias)

        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
                    weight_init.shape()[1] != in_channels:
            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape error")

        self.weight = Parameter(initializer(

@@ -738,7 +738,7 @@ class DenseQuant(Cell):
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape error")

            self.bias = Parameter(initializer(
@@ -780,8 +780,14 @@ class DenseQuant(Cell):
        return str_info

class _QuantActivation(Cell):
    r"""
    Base class for quant-aware activation functions. Adds a fake quant OP after the activation OP.
    """

    def get_origin(self):
        raise NotImplementedError
class ReLUQuant(Cell):
class ReLUQuant(_QuantActivation):
    r"""
    ReLUQuant activation function. Add Fake Quant OP after Relu OP.

@@ -828,8 +834,11 @@ class ReLUQuant(Cell):
        x = self.fake_quant_act(x)
        return x

    def get_origin(self):
        return self.relu

class ReLU6Quant(Cell):
class ReLU6Quant(_QuantActivation):
    r"""
    ReLU6Quant activation function.

@@ -878,8 +887,10 @@ class ReLU6Quant(Cell):
        x = self.fake_quant_act(x)
        return x

    def get_origin(self):
        return self.relu6

class HSwishQuant(Cell):
class HSwishQuant(_QuantActivation):
    r"""
    HSwishQuant activation function. Add Fake Quant OP after HSwish OP.

@@ -935,8 +946,10 @@ class HSwishQuant(Cell):
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        return self.act

class HSigmoidQuant(Cell):
class HSigmoidQuant(_QuantActivation):
    r"""
    HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.

@@ -991,6 +1004,8 @@ class HSigmoidQuant(Cell):
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        return self.act

class TensorAddQuant(Cell):
    r"""

@@ -1083,3 +1098,77 @@ class MulQuant(Cell):
        x = self.mul(x1, x2)
        x = self.fake_quant_act(x)
        return x

class QuantBlock(Cell):
    r"""
    A quant block of Conv/Dense, activation layer for Ascend deploy.

    Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.

    Notes:
        This block is only for deploy, and is not trainable. Instances are assembled by the
        quant export pass (see `ExportQuantNetworkDeploy` below) rather than constructed by hand.

    Args:
        core_op (Primitive): The primitive that does the int8 compute, e.g. `Conv2D` or `MatMul`.
        weight (Tensor): The quantized int8 weight.
        quant_op (Primitive): The op that quantizes the input, e.g. `AscendQuant`.
        dequant_op (Primitive): The op that dequantizes the output, e.g. `AscendDequant`.
        dequant_scale (Tensor): The scale fed to `dequant_op`.
        bias (Tensor): The bias added after `core_op`. Default: None.
        activation (Cell): The activation applied before dequantization. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.
    """
    def __init__(self,
                 core_op,
                 weight,
                 quant_op,
                 dequant_op,
                 dequant_scale,
                 bias=None,
                 activation=None):
        super(QuantBlock, self).__init__()
        self.core_op = core_op
        self.weight = weight
        self.quant = quant_op
        self.dequant = dequant_op
        self.dequant_scale = dequant_scale
        self.bias = bias
        self.has_bias = bias is not None
        # `P` is mindspore.ops.operations, assumed to be imported in this module
        self.bias_add = P.BiasAdd() if self.has_bias else None
        self.activation = activation
        self.has_act = activation is not None
    def construct(self, x):
        x = self.quant(x)
        x = self.core_op(x, self.weight)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        if self.has_act:
            x = self.activation(x)
        x = self.dequant(x, self.dequant_scale)
        return x
    def extend_repr(self):
        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
        if self.has_bias:
            str_info = str_info + f', bias={self.bias}'
        if self.has_act:
            str_info = str_info + f', activation={self.activation}'
        str_info = str_info + f', dequant={self.dequant}'
        return str_info
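For orientation, a hedged sketch of how the exporter (see `_get_quant_block` later in this diff) wires a `QuantBlock`; the scale/offset numbers are placeholders:

```python
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner

quant_op = inner.AscendQuant(0.05, 0.0)  # scale_a_in, zp_a_in (placeholder values)
dequant_op = inner.AscendDequant(False)  # sqrt_mode off
weight = Tensor(np.zeros((4, 3), np.int8), mstype.int8)
scale_deq = Tensor(0.001, mstype.float16)
block = QuantBlock(P.MatMul(transpose_b=True), weight, quant_op, dequant_op, scale_deq)
```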
@@ -584,6 +584,8 @@ class MatMul(PrimitiveWithInfer):
    def infer_dtype(self, x, y):
        args = {"x": x, "y": y}
        validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
        if x.element_type() == mstype.int8:
            return mstype.tensor_type(mstype.int32)
        return x
@@ -800,7 +800,7 @@ class Conv2D(PrimitiveWithInfer):
    def infer_shape(self, x_shape, w_shape):
        validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
        validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
        validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)

@@ -846,6 +846,8 @@ class Conv2D(PrimitiveWithInfer):
        args = {'x': x_dtype, 'w': w_dtype}
        valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_types, self.name)
        if x_dtype.element_type() == mstype.int8:
            return mstype.tensor_type(mstype.int32)
        return x_dtype
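The int8 branches added to `MatMul.infer_dtype` and `Conv2D.infer_dtype` promote the output to int32 because int8 products must accumulate in a wider type. A quick numpy illustration:

```python
import numpy as np

a = np.array([[120, 120]], dtype=np.int8)
w = np.array([[100], [100]], dtype=np.int8)
# 120*100 + 120*100 = 24000, far outside int8's [-128, 127]
print(a.astype(np.int32) @ w.astype(np.int32))  # [[24000]]
```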
@@ -43,11 +43,12 @@ class Primitive(Primitive_):
        >>> # init a Primitive obj with attr1=1 and attr2=2
        >>> add = Add(attr1=1, attr2=2)
    """
    _repr_ignore_list = ['input_names', 'output_names']

    def __init__(self, name):
        self.name = name
        self.attrs = {}
        self.init_attrs = {}
        self.init_attrs = {"name": name}
        Primitive_.__init__(self, name, self)
        if hasattr(self.__class__, '__mindspore_signature__'):
            sig = self._fill_signature(self.__class__.__mindspore_signature__)

@@ -165,6 +166,16 @@ class Primitive(Primitive_):
    def __setstate__(self, d):
        self.__dict__.update(d)

    def __deepcopy__(self, memo):
        return type(self)(**self.init_attrs)

    def __repr__(self):
        attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if k not in Primitive._repr_ignore_list])
        info_str = f'Prim[{self.name}]'
        if attr:
            info_str += f'<{attr}>'
        return info_str
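For a primitive built as in the class docstring above, the new `__repr__` would print roughly the following (attribute order follows insertion order; `input_names`/`output_names` are suppressed via `_repr_ignore_list`; `Add` is the hypothetical subclass from that docstring):

```python
>>> add = Add(attr1=1, attr2=2)
>>> repr(add)
'Prim[Add]<attr1=1, attr2=2>'
```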
    def init_prim_io_names(self, inputs, outputs):
        """
        Initializes the inputs and outputs names of Tensor or attributes.

@@ -185,8 +196,8 @@ class PrimitiveWithInfer(Primitive):
    There are four methods that can be overridden to define the infer logic of the primitive: __infer__(), infer_shape(),
    infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describle shape
    and type infer logic. The infer_value() is used for constant propogation.
    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe shape
    and type infer logic. The infer_value() is used for constant propagation.

    Args:
        name (str): Name for current Primitive.
@@ -288,6 +299,7 @@ def prim_attr_register(fn):
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        del arguments['self']
        del self.init_attrs['name']
        for name in arguments:
            value = arguments[name]
            self.add_prim_attr(name, value)
@@ -14,12 +14,23 @@
# ============================================================================
"""aware quantization."""

import copy
import re
from ... import nn
from ... import ops
import numpy as np
from ... import log as logger
from ... import nn, ops
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel
from ...common import Tensor
from ...common import dtype as mstype
from ...common.api import _executor
from ...nn.layer import quant
from ...ops import functional as F
from ...ops.operations import _inner_ops as inner
from ...train import serialization
from . import quant_utils

_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                   nn.ReLU6: quant.ReLU6Quant,

@@ -27,25 +38,21 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                   nn.HSwish: quant.HSwishQuant}

class _AddFakeQuantInputOutput(nn.Cell):
class _AddFakeQuantInput(nn.Cell):
    """
    Add FakeQuant OP at the input of the network. Only supports the one-input case.
    """
    def __init__(self, network, quant_delay=0):
        super(_AddFakeQuantInputOutput, self).__init__(auto_prefix=False)
        super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
        self.network = network
        self.fake_quant_input = quant.FakeQuantWithMinMax(
            min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
        self.fake_quant_input.update_parameters_name('fake_quant_input')
        self.fake_quant_output = quant.FakeQuantWithMinMax(
            min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
        self.fake_quant_output.update_parameters_name('fake_quant_output')

    def construct(self, data):
        data = self.fake_quant_input(data)
        output = self.network(data)
        output = self.fake_quant_output(output)
        return output

@@ -99,6 +106,8 @@ class ConvertToQuantNetwork:
        self.per_channel = validator.check_bool("per channel", per_channel)
        self.symmetric = validator.check_bool("symmetric", symmetric)
        self.narrow_range = validator.check_bool("narrow range", narrow_range)
        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
                                    quant.DenseBnAct: self._convert_dense}

    def _convert_op_name(self, name):
        pattern = re.compile(r'([A-Z]{1})')

@@ -110,6 +119,7 @@ class ConvertToQuantNetwork:
    def run(self):
        self.network.update_cell_prefix()
        network = self._convert_subcells2quant(self.network)
        network = _AddFakeQuantInput(network)
        return network

    def _convert_subcells2quant(self, network):

@@ -122,15 +132,9 @@ class ConvertToQuantNetwork:
            subcell = cells[name]
            if subcell == network:
                continue
            elif isinstance(subcell, quant.Conv2dBnAct):
            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                prefix = subcell.param_prefix
                new_subcell = self._convert_conv(subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            elif isinstance(subcell, quant.DenseBnAct):
                prefix = subcell.param_prefix
                new_subcell = self._convert_dense(subcell)
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True

@@ -199,10 +203,12 @@ class ConvertToQuantNetwork:
                                             symmetric=self.symmetric,
                                             narrow_range=self.narrow_range)
        subcell.conv = conv_inner
        if subcell.activation is not None:
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        else:
            subcell = _AddFakeQuantAfterSubCell(subcell)
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
                                                           quant_delay=self.quant_delay)
        return subcell

    def _convert_dense(self, subcell):

@@ -217,8 +223,12 @@ class ConvertToQuantNetwork:
                                      per_channel=self.per_channel,
                                      num_bits=self.weight_bits)
        subcell.dense = dense_inner
        if subcell.activation is not None:
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        else:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
                                                           quant_delay=self.quant_delay)
        return subcell

    def _convert_activation(self, activation):

@@ -229,6 +239,147 @@ class ConvertToQuantNetwork:
        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay)

class ExportQuantNetworkDeploy:
    """
    Convert quantization aware network to deploy network.

    Args:
        network (Cell): MindSpore network produced by `convert_quant_network`.
        inputs (Tensor): Inputs of the `network`.

    Returns:
        Cell, converted network.
    """
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self,
                 network,
                 *inputs):
        network = validator.check_isinstance('network', network, (nn.Cell,))
        self.data_type = mstype.int8
        self.network = copy.deepcopy(network)
        self.all_parameters = {p.name: p for p in self.network.get_parameters()}
        self.get_inputs_table(inputs)

    def get_inputs_table(self, inputs):
        """Get the support info for quant export."""
        phase_name = 'export_quant'
        graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
        self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)

    def run(self):
        """Start to convert."""
        self.network.update_cell_prefix()
        network = self.network
        if isinstance(network, _AddFakeQuantInput):
            network = network.network
        network = self._convert_quant2deploy(network)
        return network

    def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
        """Convert one of the network's quant subcells to a deploy subcell."""
        # Calculate the scale and zero point
        w_minq_name = cell_core.fake_quant_weight.minq.name
        np_type = mstype.dtype_to_nptype(self.data_type)
        scale_w, zp_w = quant_utils.scale_zp_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
        scale_a_out, _ = quant_utils.scale_zp_from_fake_quant_cell(fake_quant_a_out, np_type)
        info = self.quant_info_table.get(w_minq_name, None)
        if info:
            fake_quant_a_in_op, minq_name = info
            maxq = self.all_parameters[minq_name[:-4] + "maxq"]
            minq = self.all_parameters[minq_name]
            scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fake_quant_a_in_op, maxq, minq, np_type)
        else:
            logger.warning(f"Cannot find the input `fake_quant` whose `minq` matches {w_minq_name}")
            return None

        # Build the `Quant` and `Dequant` ops.
        # AscendQuant only supports the per-layer version; a check is needed here.
        quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
        sqrt_mode = False
        scale_deq = scale_a_out * scale_w
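        # float16 cannot represent magnitudes below 2**-14 (its smallest positive
        # normal), so (assumption) AscendDequant's sqrt_mode applies the scale
        # twice and sqrt(scale_deq) is stored instead.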
        if scale_deq < 2 ** -14:
            scale_deq = np.sqrt(scale_deq)
            sqrt_mode = True
        dequant_op = inner.AscendDequant(sqrt_mode)

        # get op
        op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
        if isinstance(activation, _AddFakeQuantAfterSubCell):
            activation = activation.subcell
        elif hasattr(activation, "get_origin"):
            activation = activation.get_origin()

        # get the `weight` and `bias`
        weight = cell_core.weight.data.asnumpy()
        bias = None
        if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
            if cell_core.has_bias:
                bias = cell_core.bias.data.asnumpy()
        elif isinstance(cell_core, quant.Conv2dBatchNormQuant):
            weight, bias = quant_utils.fold_batchnorm(weight, cell_core)

        # apply the quant
        weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type)
        if bias is not None:
            bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
        scale_deq = Tensor(scale_deq, mstype.float16)
        block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
        return block

    def _convert_quant2deploy(self, network):
        """Convert all of the network's quant subcells to deploy subcells."""
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            cell_core = None
            fake_quant_act = None
            activation = None
            if isinstance(subcell, quant.Conv2dBnAct):
                cell_core = subcell.conv
                activation = subcell.activation
                fake_quant_act = activation.fake_quant_act
            elif isinstance(subcell, quant.DenseBnAct):
                cell_core = subcell.dense
                activation = subcell.activation
                fake_quant_act = activation.fake_quant_act
            if cell_core is not None:
                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
                if new_subcell:
                    prefix = subcell.param_prefix
                    new_subcell.update_parameters_name(prefix + '.')
                    network.insert_child_to_cell(name, new_subcell)
                    change = True
            elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                op = subcell.subcell
                if op.name in self.__quant_op_name__ and isinstance(op, ops.Primitive):
                    network.__delattr__(name)
                    network.__setattr__(name, op)
                    change = True
            else:
                self._convert_quant2deploy(subcell)
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())
        return network

def export_geir(network, *inputs, file_name):
    """
    Exports MindSpore quant predict model to deploy with GEIR.

    Args:
        network (Cell): MindSpore network produced by `convert_quant_network`.
        inputs (Tensor): Inputs of the `network`.
        file_name (str): File name of model to export.
    """
    exporter = ExportQuantNetworkDeploy(network, *inputs)
    deploy_net = exporter.run()
    serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR")

def convert_quant_network(network,
                          quant_delay=0,
                          bn_fold=False,

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""quantization utils."""
"""Quantization utils."""

import numpy as np

@@ -24,22 +24,19 @@ def cal_quantization_params(input_min,
                            symmetric=False,
                            narrow_range=False):
    r"""
    calculate quantization params for scale and zero point.
    Calculate quantization params for scale and zero point.

    Args:
        input_min (int, list): The dimension of channel or 1.
        input_max (int, list): The dimension of channel or 1.
        input_min (numpy.ndarray): The dimension of channel or 1.
        input_max (numpy.ndarray): The dimension of channel or 1.
        data_type (numpy type): Can be numpy.int8 or numpy.uint8.
        num_bits (int): Number of quantization bits; supports 4 and 8. Default: 8.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.

    Outputs:
        scale (int, list): quantization param.
        zero point (int, list): quantization param.

    Examples:
        >>> scale, zp = cal_quantization_params([1, 2, 1], [-2, 0, -1], 8, False, False)
    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    input_max = np.maximum(0.0, input_max)
    input_min = np.minimum(0.0, input_min)

@@ -92,27 +89,103 @@ def weight2int(data,
               scale,
               zero_point):
    r"""
    calculate int8/uint8 weight from fp32. the formula is defined as:
    Calculate int8/uint8 weight from fp32. The formula is defined as:

    .. math::
        int8/uint8 = round(float / scale) + zero\_point

    Args:
        data (int, list): The dimension of channel or 1. Should be NCHW.
        scale (int, list): The dimension of channel or 1.
        zero_point (int, list): The dimension of channel or 1.
        data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
        scale (numpy.ndarray): The dimension of channel or 1.
        zero_point (numpy.ndarray): The dimension of channel or 1.

    Outputs:
        weight (int, list): The dimension of channel or 1.

    Examples:
        >>> weight = weight2int([1, 2, 1], 1, 0)
    Returns:
        weight (numpy.ndarray): The dimension of channel or 1.
    """
    if scale.shape != zero_point.shape:
        raise ValueError("scale and zero_point should have the same shape.")
    if scale.shape[0] > 0:
        scale = scale.reshape(1, -1, 1, 1)
        zero_point = zero_point.reshape(1, -1, 1, 1)
        scale = scale.reshape(1, -1)
        zero_point = zero_point.reshape(1, -1)
    return np.round((data/scale) + zero_point)
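A tiny worked example of the formula, assuming the per-layer path where `scale` and `zero_point` broadcast as `(1, -1)`:

```python
import numpy as np

data = np.array([[-0.5, 0.0, 0.5]], np.float32)
scale = np.array([0.25], np.float32)
zero_point = np.array([0.0], np.float32)
print(np.round((data / scale) + zero_point))  # [[-2.  0.  2.]]
```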

def scale_zp_from_fake_quant_cell(cell, data_type):
    r"""
    Calculate the quantization params (scale and zero point) from a `FakeQuantWithMinMax` cell.

    Args:
        cell (Cell): A `mindspore.nn.layer.FakeQuantWithMinMax` instance.
        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    minq = cell.minq.data.asnumpy()
    maxq = cell.maxq.data.asnumpy()
    op = cell.fake_quant
    scale, zp = cal_quantization_params(
        minq, maxq, data_type,
        num_bits=op.num_bits,
        symmetric=op.symmetric,
        narrow_range=op.narrow_range)
    return scale, zp

def scale_zp_from_data(op, minq, maxq, data_type):
    r"""
    Calculate the quantization params (scale and zero point) from a fake quant primitive
    and the `minq`/`maxq` Parameters of a `FakeQuantWithMinMax` cell.

    Args:
        op (Primitive): Fake quant primitive, `mindspore.ops.operation.FakeQuantPerLayer` or
            `mindspore.ops.operation.FakeQuantPerChannel`.
        minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`.
        maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`.
        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    minq = minq.data.asnumpy()
    maxq = maxq.data.asnumpy()
    scale, zp = cal_quantization_params(
        minq, maxq, data_type,
        num_bits=op.num_bits,
        symmetric=op.symmetric,
        narrow_range=op.narrow_range)
    return scale, zp

def fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm in `Conv2dBatchNormQuant` into the weight.

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`.

    Returns:
        weight (numpy.ndarray): Folded weight.
        bias (numpy.ndarray): Folded bias.
    """
    variance = cell_quant.moving_variance.data.asnumpy()
    mean = cell_quant.moving_mean.data.asnumpy()
    gamma = cell_quant.gamma.data.asnumpy()
    beta = cell_quant.beta.data.asnumpy()
    epsilon = cell_quant.eps
    sigma = np.sqrt(variance + epsilon)
    # compute the folded bias while everything is still 1-D (per channel)
    bias = beta - gamma * mean / sigma
    # broadcast the per-channel factors over the NCHW weight
    gamma = gamma.reshape(-1, 1, 1, 1)
    sigma = sigma.reshape(-1, 1, 1, 1)
    weight = weight * gamma / sigma
    return weight, bias
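A quick numpy sanity check of the folding identity BN(conv(x)) = x * (w * gamma / sigma) + (beta - gamma * mean / sigma), reduced to the scalar case:

```python
import numpy as np

x, w = 2.0, 3.0
gamma, beta, mean, var, eps = 1.5, 0.1, 4.0, 0.25, 1e-5
sigma = np.sqrt(var + eps)
bn_out = gamma * (x * w - mean) / sigma + beta
folded = x * (w * gamma / sigma) + (beta - gamma * mean / sigma)
print(np.isclose(bn_out, folded))  # True
```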

@@ -55,7 +55,7 @@ def init_net_param(network, init_value='ones'):
    params = network.trainable_params()
    for p in params:
        if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
            p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
            p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype))

class ModelCallback(Callback):
    def __init__(self):

@@ -13,9 +13,14 @@
# limitations under the License.
# ============================================================================
""" tests for quant """
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.quant import quant as qat
from mobilenetv2_combined import MobileNetV2

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

@@ -37,23 +42,45 @@ class LeNet5(nn.Cell):
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class

        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6')
        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu')
        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        self.fc3 = nn.DenseBnAct(84, self.num_class)
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flattern = nn.Flatten()
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.max_pool2d(x)
        x = self.flattern(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_lenet():
    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
    net = LeNet5()
    net = qat.convert_quant_network(
        net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
    # should load the checkpoint. mock here
    for param in net.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")

@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_mobile():
    net = MobileNetV2()
    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
    net = qat.convert_quant_network(
        net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8)
    # should load the checkpoint. mock here
    for param in net.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")