# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Patterns for describing graphs"""
from mindspore.ops import Primitive
from mindspore.common.tensor import Tensor
from mindspore._c_expression import Pattern, OneOf_, Prim_, Call_, NoneOf_, Any, NewTensor_, NewParameter_, Imm

__all__ = [
    "OneOf",
    "Prim",
    "Call",
    "NoneOf",
    "Any",
    "NewTensor",
    "NewParameter",
    "Imm"
]


def _is_pattern_sequence(patterns):
    """Return True if `patterns` is a tuple/list whose every element is a Pattern (an empty one qualifies)."""
    return isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns)


class OneOf(OneOf_):
    r"""
    Express a pattern which allows a list of patterns.
    """

    def __init__(self, patterns=None):
        r"""
        Args:
            patterns(Union[:class:`mindspore.graph_utils.graph_pattern`,
                           tuple[:class:`mindspore.graph_utils.graph_pattern`],
                           list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns, each
                element should be one of the exposed Pattern instance. A single Pattern is accepted and wrapped
                into a one-element list. Although the default is None, a value must be supplied: None is rejected.

        Raises:
            TypeError: raise type error for invalid inputs.
        """
        self.patterns = patterns
        if isinstance(patterns, Pattern):
            # A lone pattern is treated as a one-element list of alternatives.
            OneOf_.__init__(self, [patterns])
        elif _is_pattern_sequence(patterns):
            OneOf_.__init__(self, patterns)
        else:
            raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")


class Prim(Prim_):
    r"""
    Express a pattern of certain primitive type(s).

    NOTE:
        This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
        please refer to the `Call` pattern.
    """

    def __init__(self, types, name=None):
        r"""
        Args:
            types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
                   tuple[:class:`mindspore.ops.Primitive`]]):
                Specify allowed types.
                If it is a string, the form could be
                    1) a single primitive type, e.g. 'Conv2D'
                    2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
                It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
            name (str): name of the pattern, optional. Default: None. When omitted, the name is derived from
                `types`: the string itself, the Primitive's name, or the concatenation of all Primitive names.

        Raises:
            TypeError: raise type error for invalid argument.
        """
        if name is not None and not isinstance(name, str):
            raise TypeError(f"Expect string, got : {name}")
        self.name = name
        if isinstance(types, str):
            if self.name is None:
                self.name = types
            self.types = types.split('|')
        elif isinstance(types, Primitive):
            if self.name is None:
                self.name = types.name
            self.types = [types]
        elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
            if self.name is None:
                self.name = "".join(prim.name for prim in types)
            self.types = types
        else:
            raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
        Prim_.__init__(self, self.types, self.name)


class Call(Call_):
    r"""
    Express a primitive CNode.
    """

    def __init__(self, prim_pattern, inputs=None):
        r"""
        Args:
            prim_pattern (Union[str, :class:`mindspore.ops.Primitive`, Pattern]): Primitive ValueNode in the
                Primitive CNode; any exposed Pattern instance (e.g. a `Prim`) is also accepted.
            inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`],
                    tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
                Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified,
                input patterns should be of right order and each element should be one of the exposed Pattern
                instance.

        Raises:
            TypeError: raise type error for invalid argument.
        """
        if not isinstance(prim_pattern, (Pattern, str, Primitive)):
            raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
        self.prim_pattern = prim_pattern
        if inputs is None:
            # No input constraint requested: forwarded as an empty input list.
            self.inputs = []
        elif _is_pattern_sequence(inputs):
            self.inputs = inputs
        else:
            raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
        Call_.__init__(self, self.prim_pattern, self.inputs)


class NoneOf(NoneOf_):
    r"""
    Express a pattern which forbids a list of patterns.

    NOTE:
        NoneOf pattern should not be the root pattern.
    """

    def __init__(self, patterns=None):
        r"""
        Args:
            patterns(Union[:class:`mindspore.graph_utils.graph_pattern`,
                           tuple[:class:`mindspore.graph_utils.graph_pattern`],
                           list[:class:`mindspore.graph_utils.graph_pattern`]]): list of forbidden patterns, each
                element should be one of the exposed Pattern instance. A single Pattern is wrapped into a
                one-element list; None (the default) forbids nothing.

        Raises:
            TypeError: raise type error for invalid argument.
        """
        self.patterns = patterns
        if patterns is None:
            NoneOf_.__init__(self, ())
        elif isinstance(patterns, Pattern):
            NoneOf_.__init__(self, [patterns])
        elif _is_pattern_sequence(patterns):
            NoneOf_.__init__(self, patterns)
        else:
            raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")


class NewTensor(NewTensor_):
    r"""
    New Tensor to be used in the target.
    """

    def __init__(self, input_tensor):
        r"""
        Args:
            input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target

        Raises:
            TypeError: raise type error for invalid argument.
        """
        self.input_tensor = input_tensor
        if not isinstance(input_tensor, Tensor):
            raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
        NewTensor_.__init__(self, input_tensor)


class NewParameter(NewParameter_):
    r"""
    New Parameter to be used in the target.
    """

    def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False):
        r"""
        Args:
            para_name(str): name for the new Parameter
            default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
            requires_grad(bool): True if the parameter requires gradient. Default: False
            layerwise_parallel(bool): switch for layerwise parallel mode. Default: False

        Raises:
            TypeError: raise type error for invalid argument.
        """
        self.para_name = para_name
        self.default_tensor = default_tensor
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        args_valid = (isinstance(para_name, str) and isinstance(default_tensor, Tensor)
                      and isinstance(requires_grad, bool) and isinstance(layerwise_parallel, bool))
        if not args_valid:
            # Implicit literal concatenation keeps the message free of source indentation.
            raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), "
                            f"layerwise_parallel(bool), got : {para_name}, {default_tensor}, "
                            f"{requires_grad}, {layerwise_parallel}")
        NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
                               self.layerwise_parallel)