fix pylint problem fix conflict fix op list fix check warning fix code based on review comments update akg commit fix check warningtags/v1.4.0
| @@ -1 +1 @@ | |||
| Subproject commit f3168164c452316c21709f3293ef3b31a3688062 | |||
| Subproject commit 4aac4d95750a87e664f175c0fa946a069f8a0c2a | |||
| @@ -14,7 +14,7 @@ | |||
| # =========================================================================== | |||
| """Cost model splitter""" | |||
| import os | |||
| from functools import reduce | |||
| from functools import reduce as prod_reduce | |||
| from mindspore import log as logger | |||
| from .model import PrimLib, Graph, Tensor, Operator | |||
| from .model import DataFormat as DF | |||
| @@ -98,6 +98,7 @@ class GraphSplitByPattern: | |||
| return str(self) | |||
| def get_relation(self, op, i): | |||
| """Get op relation""" | |||
| relation = PrimLib.UNKNOWN | |||
| _, elem_relation = PrimLib.input_relation(op, i) | |||
| for r in elem_relation: | |||
| @@ -122,6 +123,7 @@ class GraphSplitByPattern: | |||
| self.reach_tab.sync(self.unique_id, out.unique_id) | |||
| def update_stitch_info(self, stitch_info): | |||
| """Update stitch info""" | |||
| if stitch_info.stitch_ops: | |||
| self.stitch_info.stitch_ops.update(stitch_info.stitch_ops) | |||
| if stitch_info.stitch_atomic_ops: | |||
| @@ -180,9 +182,11 @@ class GraphSplitByPattern: | |||
| return True | |||
| def dom_op(self): | |||
| """Get dom op""" | |||
| return self.ops[0] | |||
| def reduce_out_exclude(self, area): | |||
| """Check whether op is redcue_out_exclude """ | |||
| if self.output_excluded: | |||
| for op in self.output_excluded: | |||
| if op in area.ops: | |||
| @@ -260,6 +264,7 @@ class GraphSplitByPattern: | |||
| self.area_map[op] = area | |||
| def set_default_mode(self, area): | |||
| """Set default mode""" | |||
| area.mode = self.get_default_mode(area.ops[0]) | |||
| def limit_area_size(self, dominant, fuse_areas): | |||
| @@ -267,7 +272,7 @@ class GraphSplitByPattern: | |||
| limit_size = 200 # an experience number | |||
| area_sizes = map(lambda area: len(area.ops), fuse_areas) | |||
| dom_size = len(dominant.ops) | |||
| if dom_size + reduce(lambda x, y: x+y, area_sizes) <= limit_size: | |||
| if dom_size + prod_reduce(lambda x, y: x + y, area_sizes) <= limit_size: | |||
| return fuse_areas | |||
| # fuse the smaller area in priority | |||
| fuse_areas.sort(key=lambda area: len(area.ops)) | |||
| @@ -358,8 +363,9 @@ class GraphSplitByPattern: | |||
| with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f: | |||
| f.write(subgraphs_str) | |||
| def pattern_fuse(self, select=None): | |||
| def pattern_fuse(self, fuse_func=None): | |||
| """fuse Areas by pattern repeatedly""" | |||
| del fuse_func | |||
| raise Exception("pattern_fuse() is not implemented in {}".format(self.__class__.__name__)) | |||
| def split(self): | |||
| @@ -566,6 +572,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| REDUCE_FUSE_DEPTH = 20 | |||
| def get_default_mode(self, op): | |||
| """Get default mode in GPU""" | |||
| if op.prim == "MatMul": | |||
| return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" and op.attrs['Akg'] else \ | |||
| self.Area.MODE_BASIC | |||
| @@ -696,9 +703,14 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| if any(["Reduce" in x.prim for x in dom.ops[1:]]): | |||
| return False | |||
| op = dom.ops[0] | |||
| reduce_axis = op.attrs["reduce_axis"] | |||
| if "reduce_axis" in op.attrs: | |||
| reduce_axis = op.attrs["reduce_axis"] | |||
| elif "axis" in op.attrs: | |||
| reduce_axis = [op.attrs["axis"]] | |||
| else: | |||
| raise Exception("the operator has no attr reduce_axis or axis") | |||
| if len(op.inputs[0].shape) - 1 in reduce_axis: | |||
| reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) | |||
| reduce_size = prod_reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) | |||
| return reduce_size >= 1024 | |||
| return True | |||
| @@ -753,7 +765,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a): | |||
| if _reduce_nums(a.ops) < 2: | |||
| dom_outs = [op.output for op in dom.ops] | |||
| a_ins = [input for op in a.ops for input in op.inputs] | |||
| a_ins = [op_input for op in a.ops for op_input in op.inputs] | |||
| a_outs = [op.output for op in a.ops] | |||
| a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins] | |||
| stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins] | |||
| @@ -832,7 +844,7 @@ class GraphSplitAscend(GraphSplitByPattern): | |||
| REDUCE_FUSE_DEPTH = 10 | |||
| def get_default_mode(self, op): | |||
| """Get efault mode for op""" | |||
| """Get efault mode for Ascend""" | |||
| def _dtype_same(tensors): | |||
| dtype = tensors[0].dtype | |||
| for tensor_ in tensors: | |||
| @@ -17,6 +17,9 @@ | |||
| class Utils: | |||
| """Model utils""" | |||
| def __init__(self): | |||
| pass | |||
| @staticmethod | |||
| def get_attr_type(attr): | |||
| """Get attr type""" | |||
| @@ -54,6 +57,9 @@ class DataFormat: | |||
| FRACTAL_Z_C04 = "FRACTAL_Z_C04" | |||
| NDHWC = "NDHWC" | |||
| def __init__(self): | |||
| pass | |||
| class DataType: | |||
| """Data Type""" | |||
| @@ -73,11 +79,8 @@ class DataType: | |||
| UINT64 = "uint64" | |||
| BOOL = "bool" | |||
| class Config: | |||
| R0 = 8.0 | |||
| UB_SIZE = 256 * 1024 | |||
| MAX_BLOCK = 32 | |||
| def __init__(self): | |||
| pass | |||
| class PrimLib: | |||
| @@ -90,6 +93,9 @@ class PrimLib: | |||
| REDUCE = 4 | |||
| OPAQUE = 5 | |||
| def __init__(self): | |||
| pass | |||
| class Prim: | |||
| """Prim""" | |||
| @@ -101,6 +107,7 @@ class PrimLib: | |||
| self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x) | |||
| def default_reshape_relation(self, op, input_idx): | |||
| """Process reshape relation""" | |||
| axis_relation, elem_relation = self.unknown_relation(op, input_idx) | |||
| elem_relation = [PrimLib.RESHAPE] * len(elem_relation) | |||
| return axis_relation, elem_relation | |||
| @@ -189,6 +196,8 @@ class PrimLib: | |||
| 'ReduceSum': Prim(REDUCE), | |||
| 'ReduceMax': Prim(REDUCE), | |||
| 'ReduceMin': Prim(REDUCE), | |||
| 'Argmax': Prim(REDUCE), | |||
| 'Argmin': Prim(REDUCE), | |||
| 'Assign': Prim(ELEMWISE), | |||
| 'Sign': Prim(ELEMWISE), | |||
| 'Sin': Prim(ELEMWISE), | |||
| @@ -225,6 +234,7 @@ class PrimLib: | |||
| @classmethod | |||
| def get_prim(cls, op): | |||
| """Get op primtive""" | |||
| prim = cls.primtives.get(op.prim, None) | |||
| if prim is None: | |||
| print('[WARN] primtive is not registered: ' + op.prim) | |||
| @@ -233,22 +243,27 @@ class PrimLib: | |||
| @classmethod | |||
| def input_relation(cls, op, input_idx): | |||
| """Get op's input_relation according to input_idx""" | |||
| return cls.get_prim(op).relation_func(op, input_idx) | |||
| @classmethod | |||
| def iter_type(cls, op): | |||
| """Get op's iter type""" | |||
| return cls.get_prim(op).iter_type | |||
| @classmethod | |||
| def is_reduce(cls, op): | |||
| """Check whether op's iter type is reduce""" | |||
| return cls.get_prim(op).iter_type == cls.REDUCE | |||
| @classmethod | |||
| def calibrate_iter_size(cls, op, iter_size): | |||
| """Get calibrate_iter_size""" | |||
| return cls.get_prim(op).calibrate * iter_size | |||
| @classmethod | |||
| def dtype_bytes(cls, dtype): | |||
| """Get dtype bytes""" | |||
| bits, unit = 1, 1 | |||
| for i in range(len(dtype) - 1, 0, -1): | |||
| if dtype[i].isdecimal(): | |||
| @@ -260,6 +275,7 @@ class PrimLib: | |||
| @classmethod | |||
| def inplace_reuse(cls, op, input_idx, start_axis=0): | |||
| """Check whether op is inplace reuse""" | |||
| if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype): | |||
| return False | |||
| _, elem_relation = cls.get_prim(op).relation_func(op, input_idx) | |||
| @@ -277,6 +293,8 @@ class Tensor: | |||
| PARA_OUTPUT = 2 | |||
| class Buddy: | |||
| """Buddy""" | |||
| def __init__(self, leader): | |||
| self.members = [leader] | |||
| @@ -328,6 +346,7 @@ class Value: | |||
| return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape))) | |||
| def get_size(self): | |||
| """Get size""" | |||
| return 1 | |||
| @@ -365,6 +384,7 @@ class Graph: | |||
| self.outputs = [] | |||
| self.stitch_info = stitch_info | |||
| self.recompute_ops = recompute_ops | |||
| self.processor = "" | |||
| def set_processor(self, processor): | |||
| """Set processor""" | |||
| @@ -498,7 +518,7 @@ class AlignShape(GraphVisitor): | |||
| """Align shape""" | |||
| def __init__(self): | |||
| super().__init__() | |||
| super(AlignShape, self).__init__() | |||
| def visit(self, op): | |||
| """Visit op node""" | |||
| @@ -517,7 +537,7 @@ class AddControlBuddy(GraphVisitor): | |||
| """Add control buddy""" | |||
| def __init__(self): | |||
| super().__init__() | |||
| super(AddControlBuddy, self).__init__() | |||
| self.buddies = {} # {op : [ctrl_op]} | |||
| def visit(self, op): | |||
| @@ -536,13 +556,15 @@ class AddControlBuddy(GraphVisitor): | |||
| def visit_graph(self, graph): | |||
| """Visit graph nodes""" | |||
| super().visit_graph(graph) | |||
| super(AddControlBuddy, self).visit_graph(graph) | |||
| for owner in self.buddies: | |||
| for op in self.buddies[owner]: | |||
| owner.add_buddy(op.output) | |||
| class GraphKernelUnsupportedException(Exception): | |||
| """"GraphKernel Unsupported Exception""" | |||
| def __init__(self, message): | |||
| super().__init__() | |||
| super(GraphKernelUnsupportedException, self).__init__() | |||
| self.message = message | |||
| @@ -26,7 +26,8 @@ namespace opt { | |||
| int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); } | |||
| bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const { | |||
| std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}; | |||
| std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin, | |||
| prim::kPrimArgMax, prim::kPrimArgMin}; | |||
| return std::any_of(node_with_axis.begin(), node_with_axis.end(), | |||
| [&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); }); | |||
| } | |||
| @@ -68,6 +68,8 @@ std::vector<PrimitivePtr> GetClusterableOpList() { | |||
| #elif ENABLE_GPU | |||
| prim::kPrimACos, | |||
| prim::kPrimAcosh, | |||
| prim::kPrimArgMax, | |||
| prim::kPrimArgMin, | |||
| prim::kPrimAsin, | |||
| prim::kPrimAsinh, | |||
| prim::kPrimAssign, | |||
| @@ -0,0 +1,59 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| class ArgMax(nn.Cell): | |||
| def __init__(self, axis): | |||
| super(ArgMax, self).__init__() | |||
| self.arg_max = P.Argmax(axis=axis) | |||
| def construct(self, x): | |||
| return self.arg_max(x) | |||
| def get_output(x, axis, enable_graph_kernel=False): | |||
| context.set_context(enable_graph_kernel=enable_graph_kernel) | |||
| net = ArgMax(axis) | |||
| output = net(x) | |||
| return output | |||
| def test_argmax(): | |||
| x0 = Tensor(np.random.normal(0, 1, [2, 3, 4, 4]).astype(np.float32)) | |||
| axis0 = 3 | |||
| expect = get_output(x0, axis0, False) | |||
| output = get_output(x0, axis0, True) | |||
| assert np.allclose(expect.asnumpy(), output.asnumpy(), 0.0001, 0.0001) | |||
| x1 = Tensor(np.random.normal(0, 1, [2, 3, 1, 4]).astype(np.float32)) | |||
| axis1 = 2 | |||
| expect = get_output(x1, axis1, False) | |||
| output = get_output(x1, axis1, True) | |||
| assert np.allclose(expect.asnumpy(), output.asnumpy(), 0.0001, 0.0001) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_argmax_gpu(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| test_argmax() | |||